From 615418584fc9fcb65e4992ce3e11e7facb3ee81a Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 5 Mar 2026 12:12:52 +0800 Subject: [PATCH 1/7] InsertSync: support ping/pong multibuffer - Add pto.multi_buffer=2 attr plumbing into PlanMemory (alloc_tile -> memref.alloc). - Detect ping/pong via planned address overlap-matrix and emit dynamic event-id set/wait. - Add EnableMultiBuffer pass to materialize loop-local ping/pong selection. - Add Sync sample + runop guard; fix PTOViewToMemref typed accessor crash for bitcast/treshape. --- include/PTO/IR/PTOOps.td | 26 ++ .../InsertSync/MultiBufferSelector.h | 22 ++ .../PTO/Transforms/InsertSync/SyncCodegen.h | 3 - include/PTO/Transforms/Passes.h | 1 + include/PTO/Transforms/Passes.td | 17 ++ lib/PTO/Transforms/CMakeLists.txt | 2 + lib/PTO/Transforms/EnableMultiBuffer.cpp | 224 ++++++++++++++++++ .../InsertSync/InsertSyncAnalysis.cpp | 70 +++++- .../InsertSync/MemoryDependentAnalyzer.cpp | 36 +-- .../InsertSync/MultiBufferSelector.cpp | 58 +++++ .../Transforms/InsertSync/PTOIRTranslator.cpp | 74 +++++- .../Transforms/InsertSync/PTOInsertSync.cpp | 4 +- lib/PTO/Transforms/InsertSync/SyncCodegen.cpp | 95 +++++--- lib/PTO/Transforms/PTOPlanMemory.cpp | 52 +++- lib/PTO/Transforms/PTOPlanMemory.h | 3 +- lib/PTO/Transforms/PTOToEmitC.cpp | 77 ++++++ lib/PTO/Transforms/PTOViewToMemref.cpp | 16 +- .../test_inject_sync_multibuf_pingpong.py | 70 ++++++ test/samples/runop.sh | 22 ++ tools/ptoas/ptoas.cpp | 3 + 20 files changed, 782 insertions(+), 93 deletions(-) create mode 100644 include/PTO/Transforms/InsertSync/MultiBufferSelector.h create mode 100644 lib/PTO/Transforms/EnableMultiBuffer.cpp create mode 100644 lib/PTO/Transforms/InsertSync/MultiBufferSelector.cpp create mode 100644 test/samples/Sync/test_inject_sync_multibuf_pingpong.py diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index c967a75e..43ec6c8d 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1165,6 +1165,32 @@ def WaitFlagOp : PTO_Op<"wait_flag"> { }]; } +def SetFlagDynOp : PTO_Op<"set_flag_dyn"> { + let summary = "Set synchronization flag between pipes (dynamic event id)"; + let arguments = (ins + PTO_PipeAttr:$src_pipe, + PTO_PipeAttr:$dst_pipe, + Index:$event_id + ); + let results = (outs); + let assemblyFormat = [{ + `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + }]; +} + +def WaitFlagDynOp : PTO_Op<"wait_flag_dyn"> { + let summary = "Wait for synchronization flag (dynamic event id)"; + let arguments = (ins + PTO_PipeAttr:$src_pipe, + PTO_PipeAttr:$dst_pipe, + Index:$event_id + ); + let results = (outs); + let assemblyFormat = [{ + `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Buffer-ID Synchronization (A5) //===----------------------------------------------------------------------===// diff --git a/include/PTO/Transforms/InsertSync/MultiBufferSelector.h b/include/PTO/Transforms/InsertSync/MultiBufferSelector.h new file mode 100644 index 00000000..58669de7 --- /dev/null +++ b/include/PTO/Transforms/InsertSync/MultiBufferSelector.h @@ -0,0 +1,22 @@ +#ifndef MLIR_DIALECT_PTO_TRANSFORMS_INSERTSYNC_MULTIBUFFERSELECTOR_H +#define MLIR_DIALECT_PTO_TRANSFORMS_INSERTSYNC_MULTIBUFFERSELECTOR_H + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace pto { + +/// Build a boolean `cond` that flips between even/odd iterations across a loop +/// nest. +/// +/// - The condition is inserted at the beginning of `baseLoop`'s body. +/// - The computed parity is based on a flattened linear index across `baseLoop` +/// and all its parent `scf.for` loops, supporting non-unit steps. +/// - Returns a null Value if `baseLoop` is invalid. +Value buildLoopNestParityCond(IRRewriter &rewriter, scf::ForOp baseLoop); + +} // namespace pto +} // namespace mlir + +#endif // MLIR_DIALECT_PTO_TRANSFORMS_INSERTSYNC_MULTIBUFFERSELECTOR_H diff --git a/include/PTO/Transforms/InsertSync/SyncCodegen.h b/include/PTO/Transforms/InsertSync/SyncCodegen.h index 84d3d1b9..a599b40f 100644 --- a/include/PTO/Transforms/InsertSync/SyncCodegen.h +++ b/include/PTO/Transforms/InsertSync/SyncCodegen.h @@ -57,9 +57,6 @@ class SyncCodegen { Value GetBufferSelected(IRRewriter &rewriter, Operation *op, SyncOperation *sync); - // 生成嵌套循环的计数器 (用于多缓冲切换) - Value createNestedIndexModular(IRRewriter &rewriter, Operation *defineOp); - private: SyncIRs &syncIR_; func::FuncOp func_; diff --git a/include/PTO/Transforms/Passes.h b/include/PTO/Transforms/Passes.h index 3df9390d..5d39fbee 100644 --- a/include/PTO/Transforms/Passes.h +++ b/include/PTO/Transforms/Passes.h @@ -41,6 +41,7 @@ std::unique_ptr createLoweringSyncToPipePass(); // Creates a pass for ... std::unique_ptr createPTOInsertSyncPass(); +std::unique_ptr createPTOEnableMultiBufferPass(); // Default arch is A3 unless overridden by callers. std::unique_ptr createEmitPTOManualPass(); // Explicitly select target arch for codegen. diff --git a/include/PTO/Transforms/Passes.td b/include/PTO/Transforms/Passes.td index ab7f29df..038d115d 100644 --- a/include/PTO/Transforms/Passes.td +++ b/include/PTO/Transforms/Passes.td @@ -50,6 +50,23 @@ def PTOInsertSync : Pass<"pto-insert-sync", "func::FuncOp"> { ]; } +def PTOEnableMultiBuffer : Pass<"pto-enable-multibuffer", "func::FuncOp"> { + let summary = "Materialize multi-buffer (ping/pong) selection"; + let description = [{ + Rewrites `pto.pointer_cast(addrs=[ping,pong])` into a loop-local selector + that dynamically picks the active address. This enables planned ping/pong + buffers to take effect in the emitted C++. + }]; + + let constructor = "mlir::pto::createPTOEnableMultiBufferPass()"; + + let dependentDialects = [ + "mlir::pto::PTODialect", + "mlir::arith::ArithDialect", + "mlir::scf::SCFDialect" + ]; +} + def ConvertToPTOOp : Pass<"convert-to-pto-op"> { let summary = "Convert Ops from other dialects to PTO Ops"; let constructor = "mlir::pto::createConvertToPTOOpPass()"; diff --git a/lib/PTO/Transforms/CMakeLists.txt b/lib/PTO/Transforms/CMakeLists.txt index d9d013c9..219cd31d 100644 --- a/lib/PTO/Transforms/CMakeLists.txt +++ b/lib/PTO/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(PTOTransforms AllocToPointerCast.cpp InferPTOMemScope.cpp PTOPlanMemory.cpp + EnableMultiBuffer.cpp PTORemoveRedundantBarrier.cpp InferPTOLayout.cpp BufferizableOpInterfaceImpl.cpp @@ -22,6 +23,7 @@ add_mlir_dialect_library(PTOTransforms InsertSync/InsertSyncAnalysis.cpp InsertSync/MemoryDependentAnalyzer.cpp InsertSync/MoveSyncState.cpp + InsertSync/MultiBufferSelector.cpp InsertSync/RemoveRedundantSync.cpp InsertSync/SyncEventIdAllocation.cpp InsertSync/SyncCodegen.cpp diff --git a/lib/PTO/Transforms/EnableMultiBuffer.cpp b/lib/PTO/Transforms/EnableMultiBuffer.cpp new file mode 100644 index 00000000..8f9aabb7 --- /dev/null +++ b/lib/PTO/Transforms/EnableMultiBuffer.cpp @@ -0,0 +1,224 @@ +//===- EnableMultiBuffer.cpp - Materialize ping/pong buffer selection ------===// +// +// This pass rewrites `pto.pointer_cast(addrs=[ping, pong])` into: +// - two single-address pointer_cast ops hoisted outside a selected loop, and +// - a loop-local `arith.select` that chooses the active buffer based on the +// flattened loop iteration parity. +// +// The goal is to make multi-buffer planning observable in the emitted C++ +// (the default PointerCast lowering uses only the first address operand). +// +// Currently only double-buffer (2) is supported. +// +//===----------------------------------------------------------------------===// + +#include "PTO/IR/PTO.h" +#include "PTO/Transforms/InsertSync/MultiBufferSelector.h" +#include "PTO/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { +namespace pto { +namespace { + +#define GEN_PASS_DEF_PTOENABLEMULTIBUFFER +#include "PTO/Transforms/Passes.h.inc" + +static bool isGlobalMemRef(MemRefType ty) { + if (auto asAttr = + dyn_cast_or_null(ty.getMemorySpace())) { + auto as = asAttr.getAddressSpace(); + return (as == pto::AddressSpace::GM || as == pto::AddressSpace::Zero); + } + // Treat missing memory_space as GM. + return true; +} + +static bool getConstI64(Value v, int64_t &out) { + llvm::APInt ap; + if (!matchPattern(v, m_ConstantInt(&ap))) + return false; + out = ap.getSExtValue(); + return true; +} + +static bool isAncestorLoop(scf::ForOp ancestor, scf::ForOp loop) { + if (!ancestor || !loop) + return false; + Operation *cur = loop.getOperation(); + while (cur) { + if (cur == ancestor.getOperation()) + return true; + cur = cur->getParentOp(); + } + return false; +} + +static scf::ForOp lowestCommonAncestorLoop(ArrayRef loops) { + if (loops.empty()) + return {}; + scf::ForOp lca = loops.front(); + for (scf::ForOp loop : loops.drop_front()) { + while (lca && !isAncestorLoop(lca, loop)) + lca = lca->getParentOfType(); + if (!lca) + return {}; + } + return lca; +} + +static bool isInLoopBody(Operation *op, scf::ForOp loop) { + if (!op || !loop) + return false; + Operation *cur = op; + while (cur) { + if (cur == loop.getOperation()) + return op != cur; + cur = cur->getParentOp(); + } + return false; +} + +struct PTOEnableMultiBufferPass + : public impl::PTOEnableMultiBufferBase { + void runOnOperation() override { + func::FuncOp func = getOperation(); + IRRewriter rewriter(&getContext()); + DominanceInfo dom(func); + + DenseMap loop2Cond; + + SmallVector casts; + func.walk([&](pto::PointerCastOp op) { + if (op.getAddrs().size() > 1) + casts.push_back(op); + }); + + for (pto::PointerCastOp op : casts) { + auto mrTy = dyn_cast(op.getType()); + if (!mrTy) + continue; + if (isGlobalMemRef(mrTy)) + continue; + + auto addrs = op.getAddrs(); + if (addrs.size() != 2) { + op.emitError("only double-buffer pointer_cast (2 addresses) is supported"); + return signalPassFailure(); + } + + int64_t addr0 = 0, addr1 = 0; + if (!getConstI64(addrs[0], addr0) || !getConstI64(addrs[1], addr1) || + addr0 < 0 || addr1 < 0) { + op.emitError("expected constant non-negative i64 addrs for double-buffer pointer_cast"); + return signalPassFailure(); + } + + // Collect the enclosing loop for each use site. The resulting LCA is the + // loop in which we materialize the ping/pong selector. + SmallVector useLoops; + llvm::DenseSet seen; + for (OpOperand &use : op.getResult().getUses()) { + Operation *owner = use.getOwner(); + if (!owner) + continue; + scf::ForOp enclosing = owner->getParentOfType(); + if (!enclosing) + continue; + if (seen.insert(enclosing.getOperation()).second) + useLoops.push_back(enclosing); + } + + scf::ForOp baseLoop = lowestCommonAncestorLoop(useLoops); + if (!baseLoop) { + // No loop uses: keep behavior deterministic by dropping the extra addr. + rewriter.setInsertionPoint(op); + Attribute config = op.getConfig() ? Attribute(*op.getConfig()) : Attribute(); + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + auto collapsed = rewriter.create( + op.getLoc(), op.getType(), ValueRange{addrs[0]}, + vRow ? vRow : Value(), vCol ? vCol : Value(), config); + rewriter.replaceOp(op, collapsed.getResult()); + continue; + } + + // If the original pointer_cast is used as an operand of the selected base + // loop op, we cannot replace that use with a value defined inside the + // loop. Treat this as unsupported to avoid miscompilation. + for (OpOperand &use : op.getResult().getUses()) { + if (use.getOwner() == baseLoop.getOperation()) { + op.emitError("unsupported: multi-buffer pointer_cast used as an operand of the base scf.for"); + return signalPassFailure(); + } + } + + // Hoist two single-address pointer_cast ops just before the base loop. + rewriter.setInsertionPoint(baseLoop); + Value c0 = rewriter.create(op.getLoc(), addr0, 64); + Value c1 = rewriter.create(op.getLoc(), addr1, 64); + Attribute config = op.getConfig() ? Attribute(*op.getConfig()) : Attribute(); + Value vRow = op.getValidRow(); + Value vCol = op.getValidCol(); + + if ((vRow && !dom.dominates(vRow, baseLoop.getOperation())) || + (vCol && !dom.dominates(vCol, baseLoop.getOperation()))) { + op.emitError("unsupported: valid_row/valid_col must dominate the selected loop for hoisting"); + return signalPassFailure(); + } + + auto ptr0 = rewriter.create( + op.getLoc(), op.getType(), ValueRange{c0}, vRow ? vRow : Value(), + vCol ? vCol : Value(), config); + auto ptr1 = rewriter.create( + op.getLoc(), op.getType(), ValueRange{c1}, vRow ? vRow : Value(), + vCol ? vCol : Value(), config); + + // Build (or reuse) loop-parity condition and select the active buffer. + Value cond; + auto it = loop2Cond.find(baseLoop.getOperation()); + if (it != loop2Cond.end()) { + cond = it->second; + } else { + cond = buildLoopNestParityCond(rewriter, baseLoop); + if (!cond) { + op.emitError("failed to build loop-nest parity condition for multi-buffer selection"); + return signalPassFailure(); + } + loop2Cond[baseLoop.getOperation()] = cond; + } + + rewriter.setInsertionPointAfter(cond.getDefiningOp()); + Value selected = rewriter.create( + op.getLoc(), cond, ptr1.getResult(), ptr0.getResult()); + + // Replace uses that are inside the base loop body (including nested ops). + SmallVector toReplace; + for (OpOperand &use : op.getResult().getUses()) { + Operation *owner = use.getOwner(); + if (owner && isInLoopBody(owner, baseLoop)) + toReplace.push_back(&use); + } + for (OpOperand *use : toReplace) + use->set(selected); + + if (op.getResult().use_empty()) + op.erase(); + } + } +}; + +} // namespace + +std::unique_ptr createPTOEnableMultiBufferPass() { + return std::make_unique(); +} + +} // namespace pto +} // namespace mlir diff --git a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp index a33bf889..5910782a 100644 --- a/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp +++ b/lib/PTO/Transforms/InsertSync/InsertSyncAnalysis.cpp @@ -352,6 +352,19 @@ void InsertSyncAnalysis::InsertSyncOperation( // Back-edge dependencies may require multi-buffer event IDs. if (forEndIndex.has_value()) { int eventIdNum = GetEventIdNum(depBaseMemInfosVec); + + // Multi-buffer selection relies on a well-defined scf.for loop to compute + // the ping/pong slot. If the dependency is carried by a non-for loop, + // fall back to single-buffer synchronization. + if (eventIdNum > 1) { + auto *loopEndElem = + dyn_cast(syncIR_[forEndIndex.value()].get()); + auto loopOp = loopEndElem ? dyn_cast_or_null(loopEndElem->elementOp) + : scf::ForOp(); + if (!loopOp) { + eventIdNum = 1; + } + } setOp->eventIdNum = eventIdNum; waitOp->eventIdNum = eventIdNum; } @@ -510,16 +523,57 @@ SmallVector InsertSyncAnalysis::GetMemInfoBuffers( int InsertSyncAnalysis::GetEventIdNum( const DepBaseMemInfoPairVec &depBaseMemInfosVec) { + if (depBaseMemInfosVec.empty()) + return 1; + + auto isOverlap = [](const BaseMemInfo *a, const BaseMemInfo *b, int i, + int j) -> bool { + uint64_t aStart = a->baseAddresses[static_cast(i)]; + uint64_t bStart = b->baseAddresses[static_cast(j)]; + uint64_t aEnd = aStart + a->allocateSize; + uint64_t bEnd = bStart + b->allocateSize; + uint64_t maxStart = std::max(aStart, bStart); + uint64_t minEnd = std::min(aEnd, bEnd); + return maxStart < minEnd; + }; + for (const auto &pair : depBaseMemInfosVec) { - bool isLocalA = - pair.first && (pair.first->scope == pto::AddressSpace::MAT || - pair.first->scope == pto::AddressSpace::VEC); - bool isLocalB = - pair.second && (pair.second->scope == pto::AddressSpace::MAT || - pair.second->scope == pto::AddressSpace::VEC); - if (isLocalA || isLocalB) return 2; + const BaseMemInfo *a = pair.first; + const BaseMemInfo *b = pair.second; + if (!a || !b) { + return 1; + } + if (a->scope == pto::AddressSpace::GM || b->scope == pto::AddressSpace::GM) { + return 1; + } + + const int aSz = static_cast(a->baseAddresses.size()); + const int bSz = static_cast(b->baseAddresses.size()); + if (aSz != bSz || aSz <= 1) { + return 1; + } + + // Currently only support double-buffer (ping/pong). + if (aSz != 2) { + return 1; + } + + // Require known sizes to prove non-overlap across slots. + if (a->allocateSize == 0 || b->allocateSize == 0) { + return 1; + } + + for (int i = 0; i < aSz; i++) { + for (int j = 0; j < bSz; j++) { + bool overlap = isOverlap(a, b, i, j); + if ((i == j && !overlap) || (i != j && overlap)) { + return 1; + } + } + } } - return 1; + + return 2; } bool InsertSyncAnalysis::IsGMHazard( diff --git a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp index 50135748..339ad199 100644 --- a/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp +++ b/lib/PTO/Transforms/InsertSync/MemoryDependentAnalyzer.cpp @@ -142,32 +142,20 @@ bool MemoryDependentAnalyzer::MemAlias(const BaseMemInfo *a, } // 2. Local Memory (UB/L1) - - if (a->rootBuffer == b->rootBuffer) { - if (a->baseAddresses.empty() || b->baseAddresses.empty()) return true; - return isBufferAddressRangeOverlap(a, b); - } - - // 2.2 深层比较:穿透 View - Value realRootA = GetRealRoot(a->rootBuffer); - Value realRootB = GetRealRoot(b->rootBuffer); - - if (isTraceEnabled()) { - llvm::errs() << " [Deep Check] Surface Roots differ. Digging deeper...\n"; - printValueDebug(" Real Root A", realRootA); - printValueDebug(" Real Root B", realRootB); + // + // After PlanMemory, distinct SSA buffers may legally alias the same physical + // storage. Root-buffer identity is not sufficient: rely on planned addresses. + if (a->baseAddresses.empty() || b->baseAddresses.empty()) { + if (isTraceEnabled()) + llvm::errs() << " -> Unknown baseAddresses. Conservative overlap.\n"; + return true; } - - if (realRootA == realRootB && realRootA != nullptr) { - if (isTraceEnabled()) - llvm::errs() << " -> MATCH! Real roots are the same.\n"; - return true; - } else { - if (isTraceEnabled()) - llvm::errs() << " -> Mismatch. Real roots differ.\n"; + if (a->allocateSize == 0 || b->allocateSize == 0) { + if (isTraceEnabled()) + llvm::errs() << " -> Unknown allocateSize. Conservative overlap.\n"; + return true; } - - return false; + return isBufferAddressRangeOverlap(a, b); } bool MemoryDependentAnalyzer::isGMBufferOverlap(const BaseMemInfo *a, diff --git a/lib/PTO/Transforms/InsertSync/MultiBufferSelector.cpp b/lib/PTO/Transforms/InsertSync/MultiBufferSelector.cpp new file mode 100644 index 00000000..31c1956a --- /dev/null +++ b/lib/PTO/Transforms/InsertSync/MultiBufferSelector.cpp @@ -0,0 +1,58 @@ +#include "PTO/Transforms/InsertSync/MultiBufferSelector.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" + +using namespace mlir; + +namespace mlir { +namespace pto { + +Value buildLoopNestParityCond(IRRewriter &rewriter, scf::ForOp baseLoop) { + if (!baseLoop) + return nullptr; + + Location loc = baseLoop.getLoc(); + + // Insert at the beginning of the base loop body so it dominates all uses + // within the loop nest. + rewriter.setInsertionPointToStart(baseLoop.getBody()); + + Value idx = rewriter.create(loc, 0); + Value nElems = rewriter.create(loc, 1); + Value one = rewriter.create(loc, 1); + + // Collect loop nest from inner to outer (baseLoop, parent, ...). + SmallVector loops; + for (scf::ForOp cur = baseLoop; cur; cur = cur->getParentOfType()) + loops.push_back(cur); + + for (scf::ForOp loop : loops) { + Value iv = loop.getInductionVar(); + Value lb = loop.getLowerBound(); + Value ub = loop.getUpperBound(); + Value step = loop.getStep(); + + // iter = (iv - lb) / step + Value iter = rewriter.create( + loc, rewriter.create(loc, iv, lb), step); + idx = rewriter.create( + loc, idx, rewriter.create(loc, iter, nElems)); + + // tripCount = ceilDiv(ub - lb, step) = (ub - lb + step - 1) / step + Value span = rewriter.create(loc, ub, lb); + Value stepMinusOne = rewriter.create(loc, step, one); + Value num = rewriter.create(loc, span, stepMinusOne); + Value tripCount = rewriter.create(loc, num, step); + nElems = rewriter.create(loc, nElems, tripCount); + } + + Value two = rewriter.create(loc, 2); + Value mod = rewriter.create(loc, idx, two); + Value zero = rewriter.create(loc, 0); + return rewriter.create(loc, arith::CmpIPredicate::ne, mod, + zero); +} + +} // namespace pto +} // namespace mlir + diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 33aec28b..dfbc127e 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -258,14 +258,61 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) if (op.getAddrs().empty()) { return op.emitError("PointerCast must have at least one address operand"); } - Value rootSrc = op.getAddrs().front(); - + SmallVector baseAddresses; + baseAddresses.reserve(op.getAddrs().size()); + for (Value addr : op.getAddrs()) { + llvm::APInt apIntValue; + if (!matchPattern(addr, m_ConstantInt(&apIntValue))) { + // Variable address: be conservative and treat as unknown overlap. + baseAddresses.clear(); + break; + } + int64_t c = apIntValue.getSExtValue(); + if (c < 0) { + // Unexpected negative planned address: drop address info to stay + // conservative in dependency analysis. + baseAddresses.clear(); + break; + } + baseAddresses.push_back(static_cast(c)); + } + uint64_t sizeInBytes = 0; if (memRefType.hasStaticShape()) { - int64_t elemSize = memRefType.getElementType().getIntOrFloatBitWidth() / 8; - int64_t numElements = 1; - for (auto dim : memRefType.getShape()) numElements *= dim; - sizeInBytes = numElements * elemSize; + int64_t bitWidth = + memRefType.getElementType().getIntOrFloatBitWidth(); + uint64_t elemBytes = static_cast((bitWidth + 7) / 8); + if (elemBytes == 0) + elemBytes = 1; + + // Prefer stride-based size computation to account for padded/fractal layouts. + SmallVector strides; + int64_t offset = ShapedType::kDynamic; + if (succeeded(getStridesAndOffset(memRefType, strides, offset)) && + offset != ShapedType::kDynamic && + llvm::all_of(strides, [](int64_t s) { return s != ShapedType::kDynamic; }) && + offset >= 0) { + uint64_t maxIndex = static_cast(offset); + auto shape = memRefType.getShape(); + bool invalid = false; + for (size_t i = 0; i < shape.size(); ++i) { + int64_t dim = shape[i]; + if (dim <= 0) { + invalid = true; + break; + } + uint64_t stride = static_cast(strides[i]); + maxIndex += static_cast(dim - 1) * stride; + } + if (!invalid && !shape.empty()) { + sizeInBytes = (maxIndex + 1) * elemBytes; + } + } else { + uint64_t numElements = 1; + for (auto dim : memRefType.getShape()) + numElements *= static_cast(dim); + sizeInBytes = numElements * elemBytes; + } } pto::AddressSpace space = pto::AddressSpace::GM; @@ -277,9 +324,9 @@ LogicalResult PTOIRTranslator::UpdatePointerCastOpMemInfo(pto::PointerCastOp op) auto newMemInfo = std::make_unique( res, - rootSrc, + res, space, - SmallVector{0}, + std::move(baseAddresses), sizeInBytes ); @@ -512,9 +559,14 @@ void PTOIRTranslator::UpdateAliasBufferInfo(Value result, Value source) { auto newInfo = parentInfo->clone(result); if (!newInfo->baseAddresses.empty()) { - newInfo->baseAddresses[0] += deltaOffset; - } else { - newInfo->baseAddresses.push_back(deltaOffset); + if (deltaOffset < 0) { + // Negative offsets are unexpected for buffer views in this pipeline. + // Drop address information to stay conservative in dependency analysis. + newInfo->baseAddresses.clear(); + } else { + for (auto &addr : newInfo->baseAddresses) + addr += static_cast(deltaOffset); + } } if (newSize > 0) { diff --git a/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp b/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp index dd30bf8c..c402cd54 100644 --- a/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp @@ -57,8 +57,8 @@ struct PTOInsertSyncPass : public mlir::pto::impl::PTOInsertSyncBase(op)) { + if (isa(op)) { hasExplicitSync = true; return WalkResult::interrupt(); } diff --git a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp index 4e8c5c28..33d9f20b 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp @@ -1,5 +1,6 @@ #include "PTO/Transforms/InsertSync/SyncCodegen.h" #include "PTO/IR/PTO.h" +#include "PTO/Transforms/InsertSync/MultiBufferSelector.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "llvm/ADT/STLExtras.h" @@ -279,10 +280,8 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, Operation *op, SyncOperation *sync, bool beforeInsert) { - // 注意:GetBufferSelected 可能需要在插入 Set/Wait 之前调用,以确保 SSA 顺序 - // 但这里只是获取 Value,不影响 InsertionPoint 的设定 - Value bufferSelected = GetBufferSelected(rewriter, op, sync); - (void)bufferSelected; + // Multi-buffer needs a dynamic selector to choose between event IDs. + Value selectedEventId = GetBufferSelected(rewriter, op, sync); // [Fix] Terminator 强制前置插入 if (beforeInsert || op->hasTrait()) { @@ -293,20 +292,22 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); - auto eventId = getEventAttr(rewriter, sync->eventIds[0]); // 注意:MultiBuffer可能需要特殊处理Attr - - // 这里假设 SetFlagOp/WaitFlagOp 支持动态 Value 作为 EventID,或者您有特殊的 Op - // 如果 PTO 定义只支持 Attribute,那么上面的 GetBufferSelected 逻辑需要配合修改 Op 定义 - // 假设目前的 Op 定义如下: - if (sync->isSyncWaitType()) { - // 假设 WaitFlagOp 有支持 Value eventId 的重载或变体 - // 如果没有,这行代码可能需要调整。但在您之前的 Double Buffer 测试中,看起来它是工作的? - // 或者您是否使用了 UpdateFlagOp (带 Value)? - // 这里保持原样,只修改 InsertionPoint - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); - } else { - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); + if (!selectedEventId) { + // Fallback to single-buffer event id if selector cannot be built. + auto eventId = getEventAttr(rewriter, sync->eventIds[0]); + if (sync->isSyncWaitType()) + rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); + else + rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); + return; } + + if (sync->isSyncWaitType()) + rewriter.create(op->getLoc(), srcPipe, dstPipe, + selectedEventId); + else + rewriter.create(op->getLoc(), srcPipe, dstPipe, + selectedEventId); } Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, @@ -314,30 +315,46 @@ Value SyncCodegen::GetBufferSelected(IRRewriter &rewriter, Operation *op, if (SyncIndex2SelectBuffer.count(sync->GetSyncIndex())) { return SyncIndex2SelectBuffer[sync->GetSyncIndex()]; } - - auto parentLoop = op->getParentOfType(); - if (!parentLoop) return nullptr; - - Value counter; - if (loop2BufferCounter.count(parentLoop)) { - counter = loop2BufferCounter[parentLoop]; + + scf::ForOp baseLoop; + if (sync->lowestCommonAncestorBuffer) { + if (Operation *def = sync->lowestCommonAncestorBuffer.getDefiningOp()) + baseLoop = def->getParentOfType(); + } + if (!baseLoop && sync->GetForEndIndex().has_value()) { + int forEndIndex = sync->GetForEndIndex().value(); + if (forEndIndex >= 0 && static_cast(forEndIndex) < syncIR_.size()) { + auto *loopEndElem = dyn_cast(syncIR_[forEndIndex].get()); + if (loopEndElem) + baseLoop = dyn_cast_or_null(loopEndElem->elementOp); + } + } + if (!baseLoop) { + baseLoop = op->getParentOfType(); + } + if (!baseLoop) + return nullptr; + + // Get or build the nested-loop parity condition at the start of the base loop. + Value cond; + if (loop2BufferCounter.count(baseLoop.getOperation())) { + cond = loop2BufferCounter[baseLoop.getOperation()]; } else { - rewriter.setInsertionPointToStart(parentLoop.getBody()); - Value iv = parentLoop.getInductionVar(); - Value c2 = rewriter.create(op->getLoc(), 2); - counter = rewriter.create(op->getLoc(), iv, c2); - loop2BufferCounter[parentLoop] = counter; + cond = buildLoopNestParityCond(rewriter, baseLoop); + if (!cond) + return nullptr; + loop2BufferCounter[baseLoop.getOperation()] = cond; } - - rewriter.setInsertionPointAfter(counter.getDefiningOp()); - Value id0 = rewriter.create(op->getLoc(), sync->eventIds[0]); - Value id1 = rewriter.create(op->getLoc(), sync->eventIds[1]); - - Value isZero = rewriter.create(op->getLoc(), arith::CmpIPredicate::eq, counter, - rewriter.create(op->getLoc(), 0)); - - Value selected = rewriter.create(op->getLoc(), isZero, id0, id1); - + + rewriter.setInsertionPointAfter(cond.getDefiningOp()); + Value id0 = + rewriter.create(op->getLoc(), sync->eventIds[0]); + Value id1 = + rewriter.create(op->getLoc(), sync->eventIds[1]); + + // Select id1 on odd iterations, id0 on even iterations. + Value selected = rewriter.create(op->getLoc(), cond, id1, id0); + SyncIndex2SelectBuffer[sync->GetSyncIndex()] = selected; return selected; } diff --git a/lib/PTO/Transforms/PTOPlanMemory.cpp b/lib/PTO/Transforms/PTOPlanMemory.cpp index 24d24171..6cae57c5 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.cpp +++ b/lib/PTO/Transforms/PTOPlanMemory.cpp @@ -71,14 +71,17 @@ namespace { } // namespace -void MemLivenessAnalysis::build() { +LogicalResult MemLivenessAnalysis::build() { Region &funcRegion = func_.getBody(); Liveness live(func_); // Recursively obtaining IR information. RecursionIR(&funcRegion, live); + if (walkFailed) + return failure(); // the lifetime of the buffer. GenerateBufferLife(); //InitializeInplacePairList(); + return success(); } bool MemLivenessAnalysis::isLocalMemPlan() const { @@ -90,13 +93,21 @@ bool MemLivenessAnalysis::isGlobalWorkSpaceMemPlan() const { } void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { + if (walkFailed) + return; auto result = region->walk([&](Operation *op) { + if (walkFailed) + return WalkResult::interrupt(); // recursive control flow if (auto ifOp = dyn_cast(op)) { RecursiveIfOp(ifOp, live); + if (walkFailed) + return WalkResult::interrupt(); return WalkResult::skip(); } else if (auto forOp = dyn_cast(op)) { RecursiveForOp(forOp, live); + if (walkFailed) + return WalkResult::interrupt(); return WalkResult::skip(); } @@ -115,6 +126,29 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { if (failed(CheckLocalBufferAllocOp(op))) { return WalkResult::interrupt(); } + + // Optional multi-buffer intent: when present, PlanMemory will allocate + // ping/pong addresses for this buffer and materialize them via + // `pto.pointer_cast(addrs=[...,...])`. + if (auto mb = op->getAttr("pto.multi_buffer")) { + auto intAttr = dyn_cast(mb); + if (!intAttr) { + op->emitError("expected 'pto.multi_buffer' to be an integer attribute"); + return WalkResult::interrupt(); + } + int64_t num = intAttr.getInt(); + if (num == 1) { + // Explicitly marked as single-buffer: no action needed. + } else if (num == 2) { + // Record the multi-buffer factor for this alloc result. + if (!op->getResults().empty()) + buffer2MultiNum[op->getResult(0)] = 2; + } else { + op->emitError("only 'pto.multi_buffer = 2' is supported currently"); + return WalkResult::interrupt(); + } + } + UpdateOpBufferInfo(op, op->getResults()); return WalkResult::advance(); // } else if (isGlobalWorkSpaceMemPlan() && @@ -166,7 +200,8 @@ void MemLivenessAnalysis::RecursionIR(Region *region, Liveness live) { return WalkResult::advance(); }); if (result == WalkResult::interrupt()) { - llvm_unreachable("PlanMemory Traverse IR Failed! "); + walkFailed = true; + return; } } @@ -216,6 +251,8 @@ void MemLivenessAnalysis::UpdateForOpBufferAlias(scf::ForOp forOp) { } void MemLivenessAnalysis::RecursiveForOp(scf::ForOp forOp, Liveness live) { + if (walkFailed) + return; // Process the operation of ForOp as follows: // alloca %allocA // %0 = scf.for %arg4 = %c0 to %c1024 step %c128 iter_args(%arg5 = %4)-> @@ -228,6 +265,8 @@ void MemLivenessAnalysis::RecursiveForOp(scf::ForOp forOp, Liveness live) { UpdateOpGenInfo(forBeginSeq, GetLiveBuffersInLoop(forOp, live)); UpdateForOpInitArgsAlias(forOp); RecursionIR(&forOp.getRegion(), live); + if (walkFailed) + return; UpdateForOpBufferAlias(forOp); auto forEndSeq = UpdateLinearOperation(forOp.getOperation()); OpKillHandle(forEndSeq, live, forOp->getBlock()); @@ -257,6 +296,8 @@ void MemLivenessAnalysis::UpdateIfOpBufferAlias(scf::IfOp ifOp, } void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) { + if (walkFailed) + return; // Process the operation of IfOp as follows: // %0 = scf.if %cond -> (memref<16xf16, #pto.address_space>) // scf.yield %alloc0: memref<16xf16, #pto.address_space> @@ -264,12 +305,16 @@ void MemLivenessAnalysis::RecursiveIfOp(scf::IfOp ifOp, Liveness live) { // scf.yield %alloc1 : memref<16xf16, #pto.address_space> auto curIfThen = UpdateLinearOperation(ifOp.getOperation()); RecursionIR(&ifOp.getThenRegion(), live); + if (walkFailed) + return; auto curIfElse = UpdateLinearOperation(ifOp.getOperation()); UpdateIfOpBufferAlias(ifOp, ifOp.thenYield()); auto curIfEnd = curIfElse; if (ifOp.elseBlock()) { RecursionIR(&ifOp.getElseRegion(), live); + if (walkFailed) + return; curIfEnd = UpdateLinearOperation(ifOp.getOperation()); UpdateIfOpBufferAlias(ifOp, ifOp.elseYield()); } @@ -1955,7 +2000,8 @@ void PlanMemoryPass::runOnOperation() { // } MemLivenessAnalysis memLiveness(funcOp, this->memMode); - memLiveness.build(); + if (failed(memLiveness.build())) + return signalPassFailure(); MemPlan memPlan(this->memMode, this->enableGlobalReuse, this->enablePrintMemoryAllocatedSize, diff --git a/lib/PTO/Transforms/PTOPlanMemory.h b/lib/PTO/Transforms/PTOPlanMemory.h index e7cff4f7..6c76e6ff 100644 --- a/lib/PTO/Transforms/PTOPlanMemory.h +++ b/lib/PTO/Transforms/PTOPlanMemory.h @@ -257,7 +257,7 @@ class MemLivenessAnalysis { MemLivenessAnalysis(func::FuncOp func, MemPlanMode planMode) : func_(func), planMode(planMode) {} - void build(); + LogicalResult build(); /// linear operation info. SmallVector> linearOperation; @@ -287,6 +287,7 @@ class MemLivenessAnalysis { bool isGlobalWorkSpaceMemPlan() const; private: + bool walkFailed{false}; void RecursionIR(Region *region, Liveness live); /// Get the buffer used within the loop and defined outside the loop. diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index fd868f27..65eb543c 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -3591,6 +3591,81 @@ struct PTOWaitFlagToEmitC : public OpConversionPattern { } }; +struct PTOSetFlagDynToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFlagDynOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + std::string srcTok = pipeTokFromPipeAttr(op.getSrcPipe()); + std::string dstTok = pipeTokFromPipeAttr(op.getDstPipe()); + + // Cast the dynamic event id into the ISA event type to keep ABI stable. + auto eventTy = emitc::OpaqueType::get(ctx, "event_t"); + auto castTyAttr = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, "event_t")}); + Value eventVal = + rewriter + .create(loc, eventTy, "static_cast", + /*args=*/ArrayAttr{}, + /*template_args=*/castTyAttr, + /*operands=*/ValueRange{adaptor.getEventId()}) + .getResult(0); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}); + return success(); + } +}; + +struct PTOWaitFlagDynToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::WaitFlagDynOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + Location loc = op.getLoc(); + + std::string srcTok = pipeTokFromPipeAttr(op.getSrcPipe()); + std::string dstTok = pipeTokFromPipeAttr(op.getDstPipe()); + + auto eventTy = emitc::OpaqueType::get(ctx, "event_t"); + auto castTyAttr = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, "event_t")}); + Value eventVal = + rewriter + .create(loc, eventTy, "static_cast", + /*args=*/ArrayAttr{}, + /*template_args=*/castTyAttr, + /*operands=*/ValueRange{adaptor.getEventId()}) + .getResult(0); + + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "wait_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}); + return success(); + } +}; + struct PTOGetBufToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -7181,9 +7256,11 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 4065cd67..6de65208 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -548,6 +548,12 @@ struct PTOViewToMemrefPass auto allocType = MemRefType::get(shape, elemTy, allocLayout, tbTy.getMemorySpace()); Value alloc = rewriter.create(loc, allocType); + // Propagate multi-buffer intent so PlanMemory can allocate ping/pong + // addresses for this local buffer. + if (auto mb = op->getAttr("pto.multi_buffer")) { + alloc.getDefiningOp()->setAttr("pto.multi_buffer", mb); + } + // BindTileOp 的 Builder 会自动处理空的 Value,将其视为静态维度 auto bindOp = rewriter.create( loc, targetType, alloc, vRow ? vRow : Value(), vCol ? vCol : Value(), @@ -1100,8 +1106,12 @@ struct PTOViewToMemrefPass func.walk([&](mlir::pto::TReshapeOp op) { reshapes.push_back(op); }); for (auto op : reshapes) { + // NOTE: After Stage 0.5 (alloc_tile -> memref.alloc + bind_tile), the + // operand types of tile_buf view ops may be rewritten to MemRefType, + // which temporarily breaks the ODS-typed accessors (they would + // llvm::cast(...) and assert). Use raw operands here. Value src = op->getOperand(0); - auto tbTy = dyn_cast(op->getResult(0).getType()); + auto tbTy = dyn_cast(op.getResult().getType()); if (!tbTy) { op.emitError("treshape result must be tile_buf type"); signalPassFailure(); @@ -1118,8 +1128,10 @@ struct PTOViewToMemrefPass func.walk([&](mlir::pto::BitcastOp op) { bitcasts.push_back(op); }); for (auto op : bitcasts) { + // See note above: avoid typed accessors on ops whose operands have been + // rewritten to memref values in Stage 0.5. Value src = op->getOperand(0); - auto tbTy = dyn_cast(op->getResult(0).getType()); + auto tbTy = dyn_cast(op.getResult().getType()); if (!tbTy) { op.emitError("bitcast result must be tile_buf type"); signalPassFailure(); diff --git a/test/samples/Sync/test_inject_sync_multibuf_pingpong.py b/test/samples/Sync/test_inject_sync_multibuf_pingpong.py new file mode 100644 index 00000000..858e13bb --- /dev/null +++ b/test/samples/Sync/test_inject_sync_multibuf_pingpong.py @@ -0,0 +1,70 @@ +from mlir.ir import ( + Context, + Location, + Module, + InsertionPoint, + F16Type, + IndexType, + IntegerAttr, + IntegerType, + MemRefType, +) +from mlir.dialects import arith, func, memref, scf, pto + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f16 = F16Type.get(ctx) + idx = IndexType.get(ctx) + i32 = IntegerType.get_signless(32, ctx) + + gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM, ctx) + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + + gm_ty = MemRefType.get([16, 16, 16], f16, memory_space=gm) + ub_ty = MemRefType.get([16, 16, 16], f16, memory_space=vec) + + fn_ty = func.FunctionType.get([gm_ty, gm_ty], []) + with InsertionPoint(m.body): + fn = func.FuncOp("test_inject_sync_multibuf_pingpong_py", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src, dst = entry.arguments + + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c4 = arith.ConstantOp(idx, 4).result + + alloc = memref.AllocOp(ub_ty, [], []) + ub = alloc.result + alloc.operation.attributes["pto.multi_buffer"] = IntegerAttr.get( + i32, 2 + ) + + # Loop-carried hazard: + # - TLOAD writes to UB on PIPE_MTE2. + # - TSTORE reads from UB on PIPE_MTE3. + # With multi-buffer enabled, the compiler should materialize a + # ping/pong selector and use dynamic event-id sync on the + # back-edge dependency. + loop = scf.ForOp(c0, c4, c1, []) + with InsertionPoint(loop.body): + pto.TLoadOp(None, src, ub) + pto.TStoreOp(None, ub, dst) + scf.YieldOp([]) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) + diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 0523f85e..7fa5b743 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -282,6 +282,28 @@ process_one_dir() { fi fi + # Regression guard: planned ping/pong buffers must be materialized as a + # loop-local selector + dynamic event-id set/wait on back-edge deps. + if [[ "$base" == "test_inject_sync_multibuf_pingpong" ]]; then + if ! grep -Fq "static_cast" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing dynamic event-id (static_cast) in generated C++" + overall=1 + continue + fi + local tassign_count + tassign_count="$(grep -c "TASSIGN(" "$cpp")" + if [[ "${tassign_count}" -lt 2 ]]; then + echo -e "${A}(${base}.py)\tFAIL\texpected >=2 TASSIGN calls (ping/pong pointer_cast hoisting)" + overall=1 + continue + fi + if ! grep -Fq "?" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing ternary select for ping/pong buffer selection" + overall=1 + continue + fi + fi + # Regression guard: intra-pipe dependencies must be serialized by a # per-pipe barrier (PyPTO expects `bar_v` / `bar_m` behavior). if [[ "$base" == "test_inject_sync_intra_pipe_barrier" ]]; then diff --git a/tools/ptoas/ptoas.cpp b/tools/ptoas/ptoas.cpp index 480c763d..33ba080f 100644 --- a/tools/ptoas/ptoas.cpp +++ b/tools/ptoas/ptoas.cpp @@ -664,6 +664,9 @@ int main(int argc, char **argv) { } } + // Materialize ping/pong selection for planned multi-buffer pointer_cast ops. + pm.addNestedPass(pto::createPTOEnableMultiBufferPass()); + // pm.addNestedPass(pto::createPTORemoveRedundantBarrierPass()); // pm.addNestedPass(pto::createPTOHighDimLoweringPass()); // pm.addNestedPass(pto::createPTOVFloopGatherPass()); From 0f04fbce21570caefedeb4a8c823348bc13ab6b2 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 5 Mar 2026 12:55:21 +0800 Subject: [PATCH 2/7] EmitC: fix set/wait dyn argument emission emitc.call_opaque requires an IntegerAttr placeholder to print SSA operands. Add the operand placeholder for event_id so set_flag/wait_flag receive the dynamic event argument, and extend the Sync multibuf runop guard to catch missing 3rd argument. --- lib/PTO/Transforms/PTOToEmitC.cpp | 4 ++++ test/samples/runop.sh | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 65eb543c..62dfecf6 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -3615,9 +3615,12 @@ struct PTOSetFlagDynToEmitC /*operands=*/ValueRange{adaptor.getEventId()}) .getResult(0); + // NOTE: emitc.call_opaque mixes literal tokens (OpaqueAttr) and SSA operands + // via integer placeholders. IntegerAttr(0) prints operands[0]. auto argsAttr = rewriter.getArrayAttr({ emitc::OpaqueAttr::get(ctx, srcTok), emitc::OpaqueAttr::get(ctx, dstTok), + IntegerAttr::get(IndexType::get(ctx), 0), }); rewriter.replaceOpWithNewOp( @@ -3655,6 +3658,7 @@ struct PTOWaitFlagDynToEmitC auto argsAttr = rewriter.getArrayAttr({ emitc::OpaqueAttr::get(ctx, srcTok), emitc::OpaqueAttr::get(ctx, dstTok), + IntegerAttr::get(IndexType::get(ctx), 0), }); rewriter.replaceOpWithNewOp( diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 7fa5b743..1b3f2d3a 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -290,6 +290,19 @@ process_one_dir() { overall=1 continue fi + # Ensure the dynamic event id is actually passed to set/wait flag calls. + # (CallOpaqueOp requires an IntegerAttr placeholder, otherwise the operand + # is silently dropped and we emit a 2-arg wait_flag/set_flag.) + if ! grep -Eq "wait_flag\\(PIPE_MTE3, PIPE_MTE2, v[0-9]+\\)" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing dynamic wait_flag(..., ) for ping/pong back-edge" + overall=1 + continue + fi + if ! grep -Eq "set_flag\\(PIPE_MTE3, PIPE_MTE2, v[0-9]+\\)" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing dynamic set_flag(..., ) for ping/pong back-edge" + overall=1 + continue + fi local tassign_count tassign_count="$(grep -c "TASSIGN(" "$cpp")" if [[ "${tassign_count}" -lt 2 ]]; then From 3a1958d8f01f7f9bdb1ccb94c77f94703952c90c Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 5 Mar 2026 13:13:35 +0800 Subject: [PATCH 3/7] EnableMultiBuffer: rematerialize alias chain in loops - Track view-like alias closures (bind_tile/subview/casts) from multi-address pointer_cast. - Build ping/pong selector in the LCA loop and rematerialize loop-local alias ops so tile allocations also switch ping/pong. - Update multibuf pingpong sample to use alloc_tile + TLOAD/TSTORE so the generated C++ builds on A2/A3 pto-isa. --- lib/PTO/Transforms/EnableMultiBuffer.cpp | 182 +++++++++++++++--- .../test_inject_sync_multibuf_pingpong.py | 53 ++--- 2 files changed, 188 insertions(+), 47 deletions(-) diff --git a/lib/PTO/Transforms/EnableMultiBuffer.cpp b/lib/PTO/Transforms/EnableMultiBuffer.cpp index 8f9aabb7..c798fe78 100644 --- a/lib/PTO/Transforms/EnableMultiBuffer.cpp +++ b/lib/PTO/Transforms/EnableMultiBuffer.cpp @@ -17,8 +17,10 @@ #include "PTO/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseSet.h" @@ -120,19 +122,88 @@ struct PTOEnableMultiBufferPass return signalPassFailure(); } - // Collect the enclosing loop for each use site. The resulting LCA is the - // loop in which we materialize the ping/pong selector. + // Track view-like alias chains so we can materialize ping/pong selection + // even when the pointer_cast is consumed by a BindTile/SubView outside + // the loop (common after alloc_tile -> memref.alloc + bind_tile lowering). + DenseMap aliasResult2Op; + DenseMap aliasResult2Source; + + SmallVector closure; + SmallVector worklist{op.getResult()}; + llvm::DenseSet visited; + + // Collect the enclosing loop for each loop use site across the alias + // closure. The resulting LCA is the loop in which we materialize the + // ping/pong selector and any needed view-like rematerializations. SmallVector useLoops; - llvm::DenseSet seen; - for (OpOperand &use : op.getResult().getUses()) { - Operation *owner = use.getOwner(); + llvm::DenseSet seenLoops; + + auto recordUseLoop = [&](Operation *owner) { if (!owner) - continue; + return; scf::ForOp enclosing = owner->getParentOfType(); if (!enclosing) - continue; - if (seen.insert(enclosing.getOperation()).second) + return; + if (seenLoops.insert(enclosing.getOperation()).second) useLoops.push_back(enclosing); + }; + + while (!worklist.empty()) { + Value v = worklist.pop_back_val(); + if (!visited.insert(v).second) + continue; + closure.push_back(v); + + for (OpOperand &use : v.getUses()) { + Operation *owner = use.getOwner(); + if (!owner) + continue; + + // Alias ops: propagate from source -> result. + if (auto bt = dyn_cast(owner)) { + if (use.getOperandNumber() == 0) { + Value res = bt.getResult(); + if (aliasResult2Op.try_emplace(res, owner).second) { + aliasResult2Source[res] = v; + worklist.push_back(res); + } + continue; + } + } + if (auto sv = dyn_cast(owner)) { + if (use.getOperandNumber() == 0) { + Value res = sv.getResult(); + if (aliasResult2Op.try_emplace(res, owner).second) { + aliasResult2Source[res] = v; + worklist.push_back(res); + } + continue; + } + } + if (auto rc = dyn_cast(owner)) { + if (use.getOperandNumber() == 0) { + Value res = rc.getResult(); + if (aliasResult2Op.try_emplace(res, owner).second) { + aliasResult2Source[res] = v; + worklist.push_back(res); + } + continue; + } + } + if (auto cast = dyn_cast(owner)) { + if (use.getOperandNumber() == 0) { + Value res = cast.getResult(); + if (aliasResult2Op.try_emplace(res, owner).second) { + aliasResult2Source[res] = v; + worklist.push_back(res); + } + continue; + } + } + + // Non-alias use: record for loop LCA computation. + recordUseLoop(owner); + } } scf::ForOp baseLoop = lowestCommonAncestorLoop(useLoops); @@ -149,13 +220,15 @@ struct PTOEnableMultiBufferPass continue; } - // If the original pointer_cast is used as an operand of the selected base - // loop op, we cannot replace that use with a value defined inside the - // loop. Treat this as unsupported to avoid miscompilation. - for (OpOperand &use : op.getResult().getUses()) { - if (use.getOwner() == baseLoop.getOperation()) { - op.emitError("unsupported: multi-buffer pointer_cast used as an operand of the base scf.for"); - return signalPassFailure(); + // If any value in the alias closure is used as an operand of the selected + // base loop op, we cannot safely rewrite that use with a value defined + // inside the loop. Treat this as unsupported to avoid miscompilation. + for (Value v : closure) { + for (OpOperand &use : v.getUses()) { + if (use.getOwner() == baseLoop.getOperation()) { + op.emitError("unsupported: multi-buffer value used as an operand of the base scf.for"); + return signalPassFailure(); + } } } @@ -198,15 +271,78 @@ struct PTOEnableMultiBufferPass Value selected = rewriter.create( op.getLoc(), cond, ptr1.getResult(), ptr0.getResult()); - // Replace uses that are inside the base loop body (including nested ops). - SmallVector toReplace; - for (OpOperand &use : op.getResult().getUses()) { - Operation *owner = use.getOwner(); - if (owner && isInLoopBody(owner, baseLoop)) - toReplace.push_back(&use); + // Materialize loop-local equivalents of values in the alias closure. + DenseMap loopLocal; + loopLocal[op.getResult()] = selected; + Operation *insertAfter = selected.getDefiningOp(); + + auto materialize = [&](Value v, auto &materializeRef) -> Value { + if (auto it = loopLocal.find(v); it != loopLocal.end()) + return it->second; + + // If this value is already defined inside the base loop body, reuse it + // (the source operands will be rewritten separately as needed). + if (Operation *def = v.getDefiningOp()) { + if (isInLoopBody(def, baseLoop)) { + loopLocal[v] = v; + return v; + } + } else if (auto barg = dyn_cast(v)) { + if (barg.getOwner() == baseLoop.getBody()) { + loopLocal[v] = v; + return v; + } + } + + auto it = aliasResult2Op.find(v); + if (it == aliasResult2Op.end()) + return Value(); + + Operation *aliasOp = it->second; + Value src = aliasResult2Source.lookup(v); + Value localSrc = materializeRef(src, materializeRef); + if (!localSrc) + return Value(); + + // If the alias op already lives inside the base loop body, we expect + // its operands to be rewritten via the generic use replacement below. + if (isInLoopBody(aliasOp, baseLoop)) { + loopLocal[v] = v; + return v; + } + + rewriter.setInsertionPointAfter(insertAfter); + mlir::IRMapping mapping; + mapping.map(src, localSrc); + Operation *cloned = rewriter.clone(*aliasOp, mapping); + insertAfter = cloned; + + Value res = cloned->getResult(0); + loopLocal[v] = res; + return res; + }; + + // Replace uses that are inside the base loop body (including nested ops) + // with the loop-local equivalents. + for (Value v : closure) { + SmallVector toReplace; + for (OpOperand &use : v.getUses()) { + Operation *owner = use.getOwner(); + if (owner && isInLoopBody(owner, baseLoop)) + toReplace.push_back(&use); + } + if (toReplace.empty()) + continue; + + Value repl = materialize(v, materialize); + if (!repl) { + op.emitError("failed to materialize loop-local alias for multi-buffer value"); + return signalPassFailure(); + } + + for (OpOperand *use : toReplace) + use->set(repl); } - for (OpOperand *use : toReplace) - use->set(selected); if (op.getResult().use_empty()) op.erase(); diff --git a/test/samples/Sync/test_inject_sync_multibuf_pingpong.py b/test/samples/Sync/test_inject_sync_multibuf_pingpong.py index 858e13bb..81f3bb8b 100644 --- a/test/samples/Sync/test_inject_sync_multibuf_pingpong.py +++ b/test/samples/Sync/test_inject_sync_multibuf_pingpong.py @@ -1,15 +1,6 @@ -from mlir.ir import ( - Context, - Location, - Module, - InsertionPoint, - F16Type, - IndexType, - IntegerAttr, - IntegerType, - MemRefType, -) -from mlir.dialects import arith, func, memref, scf, pto +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import arith, func, scf, pto +from mlir.ir import F16Type, IndexType, IntegerAttr, IntegerType def build(): @@ -23,13 +14,20 @@ def build(): idx = IndexType.get(ctx) i32 = IntegerType.get_signless(32, ctx) - gm = pto.AddressSpaceAttr.get(pto.AddressSpace.GM, ctx) vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) - gm_ty = MemRefType.get([16, 16, 16], f16, memory_space=gm) - ub_ty = MemRefType.get([16, 16, 16], f16, memory_space=vec) + ptr_f16 = pto.PtrType.get(f16, ctx) + tv2_f16 = pto.TensorViewType.get(2, f16, ctx) + tile_view_16 = pto.PartitionTensorViewType.get([16, 16], f16, ctx) - fn_ty = func.FunctionType.get([gm_ty, gm_ty], []) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + tile_buf_16 = pto.TileBufType.get([16, 16], f16, vec, [16, 16], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f16, ptr_f16], []) with InsertionPoint(m.body): fn = func.FuncOp("test_inject_sync_multibuf_pingpong_py", fn_ty) entry = fn.add_entry_block() @@ -40,12 +38,20 @@ def build(): c0 = arith.ConstantOp(idx, 0).result c1 = arith.ConstantOp(idx, 1).result c4 = arith.ConstantOp(idx, 4).result - - alloc = memref.AllocOp(ub_ty, [], []) + c16 = arith.ConstantOp(idx, 16).result + + tv_in = pto.MakeTensorViewOp(tv2_f16, src, [c16, c16], [c16, c1]).result + tv_out = pto.MakeTensorViewOp(tv2_f16, dst, [c16, c16], [c16, c1]).result + sv_in = pto.PartitionViewOp( + tile_view_16, tv_in, offsets=[c0, c0], sizes=[c16, c16] + ).result + sv_out = pto.PartitionViewOp( + tile_view_16, tv_out, offsets=[c0, c0], sizes=[c16, c16] + ).result + + alloc = pto.AllocTileOp(tile_buf_16) + alloc.operation.attributes["pto.multi_buffer"] = IntegerAttr.get(i32, 2) ub = alloc.result - alloc.operation.attributes["pto.multi_buffer"] = IntegerAttr.get( - i32, 2 - ) # Loop-carried hazard: # - TLOAD writes to UB on PIPE_MTE2. @@ -55,8 +61,8 @@ def build(): # back-edge dependency. loop = scf.ForOp(c0, c4, c1, []) with InsertionPoint(loop.body): - pto.TLoadOp(None, src, ub) - pto.TStoreOp(None, ub, dst) + pto.TLoadOp(None, sv_in, ub) + pto.TStoreOp(None, ub, sv_out) scf.YieldOp([]) func.ReturnOp([]) @@ -67,4 +73,3 @@ def build(): if __name__ == "__main__": print(build()) - From 5c7d45934ffb6f15f22745480ec168bf589521dc Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 5 Mar 2026 13:24:52 +0800 Subject: [PATCH 4/7] EnableMultiBuffer: select addr to avoid Tile pointer casts Materialize ping/pong by selecting between i64 addresses and building a loop-local PointerCastOp. This keeps bind_tile lowering able to trace the defining PointerCastOp and avoids generating C++ that casts Tile<> to __ubuf__ pointers (which breaks A2/A3 compilation). Update the multibuf runop guard accordingly. --- lib/PTO/Transforms/EnableMultiBuffer.cpp | 22 +++++++++++---------- test/samples/runop.sh | 25 ++++++++++++++++++------ 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/lib/PTO/Transforms/EnableMultiBuffer.cpp b/lib/PTO/Transforms/EnableMultiBuffer.cpp index c798fe78..938d7b49 100644 --- a/lib/PTO/Transforms/EnableMultiBuffer.cpp +++ b/lib/PTO/Transforms/EnableMultiBuffer.cpp @@ -246,13 +246,6 @@ struct PTOEnableMultiBufferPass return signalPassFailure(); } - auto ptr0 = rewriter.create( - op.getLoc(), op.getType(), ValueRange{c0}, vRow ? vRow : Value(), - vCol ? vCol : Value(), config); - auto ptr1 = rewriter.create( - op.getLoc(), op.getType(), ValueRange{c1}, vRow ? vRow : Value(), - vCol ? vCol : Value(), config); - // Build (or reuse) loop-parity condition and select the active buffer. Value cond; auto it = loop2Cond.find(baseLoop.getOperation()); @@ -267,14 +260,23 @@ struct PTOEnableMultiBufferPass loop2Cond[baseLoop.getOperation()] = cond; } + // Select the *address* and build a loop-local pointer_cast. This keeps + // `pto.bind_tile` lowering able to trace back to a defining PointerCastOp + // and avoids relying on C++ casts from Tile<> to raw pointers (which are + // not guaranteed to be supported by the ISA headers). rewriter.setInsertionPointAfter(cond.getDefiningOp()); - Value selected = rewriter.create( - op.getLoc(), cond, ptr1.getResult(), ptr0.getResult()); + Value selectedAddr = + rewriter.create(op.getLoc(), cond, c1, c0); + rewriter.setInsertionPointAfter(selectedAddr.getDefiningOp()); + auto selectedCast = rewriter.create( + op.getLoc(), op.getType(), ValueRange{selectedAddr}, + vRow ? vRow : Value(), vCol ? vCol : Value(), config); + Value selected = selectedCast.getResult(); // Materialize loop-local equivalents of values in the alias closure. DenseMap loopLocal; loopLocal[op.getResult()] = selected; - Operation *insertAfter = selected.getDefiningOp(); + Operation *insertAfter = selectedCast.getOperation(); auto materialize = [&](Value v, auto &materializeRef) -> Value { if (auto it = loopLocal.find(v); it != loopLocal.end()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 1b3f2d3a..1d628aa4 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -303,15 +303,28 @@ process_one_dir() { overall=1 continue fi - local tassign_count - tassign_count="$(grep -c "TASSIGN(" "$cpp")" - if [[ "${tassign_count}" -lt 2 ]]; then - echo -e "${A}(${base}.py)\tFAIL\texpected >=2 TASSIGN calls (ping/pong pointer_cast hoisting)" + if ! grep -Fq "?" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing ternary select for ping/pong buffer selection" overall=1 continue fi - if ! grep -Fq "?" "$cpp"; then - echo -e "${A}(${base}.py)\tFAIL\tmissing ternary select for ping/pong buffer selection" + if ! grep -Fq "TASSIGN(" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing TASSIGN for selected ping/pong address" + overall=1 + continue + fi + if ! grep -Eq "int64_t v[0-9]+ = 0;" "$cpp" || ! grep -Eq "int64_t v[0-9]+ = 512;" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\texpected ping/pong address constants (0 and 512) in generated C++" + overall=1 + continue + fi + if ! grep -Eq "int64_t v[0-9]+ = .*\\? .* : .*;" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing int64 ternary selection for ping/pong address" + overall=1 + continue + fi + if grep -Fq "(__ubuf__" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tunexpected Tile->pointer cast (may break NPU compilation)" overall=1 continue fi From e0fdc6d1c4a0fa08e3cb369528a2c23e84fd310c Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Thu, 5 Mar 2026 13:28:31 +0800 Subject: [PATCH 5/7] runop: relax multibuf ping/pong address guard The ping/pong base addresses are an implementation detail of PlanMemory. Check for >=2 distinct int64 constants + ternary address selection, rather than hard-coding 0/512. --- test/samples/runop.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 1d628aa4..0f2cf14e 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -313,8 +313,10 @@ process_one_dir() { overall=1 continue fi - if ! grep -Eq "int64_t v[0-9]+ = 0;" "$cpp" || ! grep -Eq "int64_t v[0-9]+ = 512;" "$cpp"; then - echo -e "${A}(${base}.py)\tFAIL\texpected ping/pong address constants (0 and 512) in generated C++" + local uniq_i64_count + uniq_i64_count="$(awk '/^[[:space:]]*int64_t v[0-9]+ = -?[0-9]+;[[:space:]]*$/{print $4}' "$cpp" | tr -d ';' | sort -u | wc -l | tr -d ' ')" + if [[ -z "${uniq_i64_count}" || "${uniq_i64_count}" -lt 2 ]]; then + echo -e "${A}(${base}.py)\tFAIL\texpected >=2 distinct int64 constants for ping/pong addresses" overall=1 continue fi From 4183bf4ce02cd7821ac82adb605f574cfbedcc84 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 6 Mar 2026 21:24:01 +0800 Subject: [PATCH 6/7] test(sync): add subset-based multibuffer pingpong regression --- ...st_inject_sync_multibuf_subset_pingpong.py | 82 +++++++++++++++++++ test/samples/runop.sh | 25 ++++++ 2 files changed, 107 insertions(+) create mode 100644 test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py diff --git a/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py b/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py new file mode 100644 index 00000000..9ac857bc --- /dev/null +++ b/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py @@ -0,0 +1,82 @@ +from mlir.ir import Context, Location, Module, InsertionPoint +from mlir.dialects import arith, func, scf, pto +from mlir.ir import F16Type, IndexType + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + + f16 = F16Type.get(ctx) + idx = IndexType.get(ctx) + + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + + ptr_f16 = pto.PtrType.get(f16, ctx) + tv2_f16 = pto.TensorViewType.get(2, f16, ctx) + tile_view_16 = pto.PartitionTensorViewType.get([16, 16], f16, ctx) + + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + workspace_ty = pto.TileBufType.get([16, 32], f16, vec, [16, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f16, ptr_f16], []) + with InsertionPoint(m.body): + fn = func.FuncOp("test_inject_sync_multibuf_subset_pingpong_py", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + src, dst = entry.arguments + + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c2 = arith.ConstantOp(idx, 2).result + c4 = arith.ConstantOp(idx, 4).result + c16 = arith.ConstantOp(idx, 16).result + + tv_in = pto.MakeTensorViewOp(tv2_f16, src, [c16, c16], [c16, c1]).result + tv_out = pto.MakeTensorViewOp(tv2_f16, dst, [c16, c16], [c16, c1]).result + sv_in = pto.PartitionViewOp( + tile_view_16, tv_in, offsets=[c0, c0], sizes=[c16, c16] + ).result + sv_out = pto.PartitionViewOp( + tile_view_16, tv_out, offsets=[c0, c0], sizes=[c16, c16] + ).result + + # Hand-written multibuffer style: + # one workspace tile split into ping/pong by subset. + workspace = pto.AllocTileOp(workspace_ty).result + ping = pto.SubsetOp(workspace, [c0, c0], sizes=[16, 16]).result + pong = pto.SubsetOp(workspace, [c0, c16], sizes=[16, 16]).result + + loop = scf.ForOp(c0, c4, c1, []) + with InsertionPoint(loop.body): + parity = arith.RemUIOp(loop.induction_variable, c2).result + is_ping = arith.CmpIOp(arith.CmpIPredicate.eq, parity, c0).result + + slot_if = scf.IfOp(is_ping, [], hasElse=True) + with InsertionPoint(slot_if.then_block): + pto.TLoadOp(None, sv_in, ping) + pto.TStoreOp(None, ping, sv_out) + scf.YieldOp([]) + with InsertionPoint(slot_if.else_block): + pto.TLoadOp(None, sv_in, pong) + pto.TStoreOp(None, pong, sv_out) + scf.YieldOp([]) + + scf.YieldOp([]) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 0f2cf14e..7ca49381 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -332,6 +332,31 @@ process_one_dir() { fi fi + # Regression guard: handwritten multibuffer (subset ping/pong) should keep + # subset-based slot split and branch-local load/store structure. + if [[ "$base" == "test_inject_sync_multibuf_subset_pingpong" ]]; then + if ! grep -Fq "pto.subset" "$pto_input"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing pto.subset in source PTO IR" + overall=1 + continue + fi + local tassign_count + tassign_count="$(grep -c "TASSIGN(" "$cpp" || true)" + if [[ -z "${tassign_count}" || "${tassign_count}" -lt 3 ]]; then + echo -e "${A}(${base}.py)\tFAIL\texpected workspace+ping+pong TASSIGN lowering" + overall=1 + continue + fi + local tload_count tstore_count + tload_count="$(grep -c "TLOAD(" "$cpp" || true)" + tstore_count="$(grep -c "TSTORE(" "$cpp" || true)" + if [[ -z "${tload_count}" || "${tload_count}" -lt 2 || -z "${tstore_count}" || "${tstore_count}" -lt 2 ]]; then + echo -e "${A}(${base}.py)\tFAIL\texpected ping/pong branch-local TLOAD/TSTORE" + overall=1 + continue + fi + fi + # Regression guard: intra-pipe dependencies must be serialized by a # per-pipe barrier (PyPTO expects `bar_v` / `bar_m` behavior). if [[ "$base" == "test_inject_sync_intra_pipe_barrier" ]]; then From 072789ebba4573f297b281e111bb37e43416bc46 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Sun, 8 Mar 2026 20:42:17 +0800 Subject: [PATCH 7/7] test(sync): mark subset pingpong as multi_buffer and assert annotation --- .../Sync/test_inject_sync_multibuf_subset_pingpong.py | 9 +++++++-- test/samples/runop.sh | 5 +++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py b/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py index 9ac857bc..f1374f8a 100644 --- a/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py +++ b/test/samples/Sync/test_inject_sync_multibuf_subset_pingpong.py @@ -1,6 +1,6 @@ from mlir.ir import Context, Location, Module, InsertionPoint from mlir.dialects import arith, func, scf, pto -from mlir.ir import F16Type, IndexType +from mlir.ir import F16Type, IndexType, IntegerAttr, IntegerType def build(): @@ -12,6 +12,7 @@ def build(): f16 = F16Type.get(ctx) idx = IndexType.get(ctx) + i32 = IntegerType.get_signless(32, ctx) vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) @@ -51,7 +52,11 @@ def build(): # Hand-written multibuffer style: # one workspace tile split into ping/pong by subset. - workspace = pto.AllocTileOp(workspace_ty).result + # `pto.multi_buffer=2` tells PlanMemory/InsertSync this is a + # ping/pong candidate. + alloc = pto.AllocTileOp(workspace_ty) + alloc.operation.attributes["pto.multi_buffer"] = IntegerAttr.get(i32, 2) + workspace = alloc.result ping = pto.SubsetOp(workspace, [c0, c0], sizes=[16, 16]).result pong = pto.SubsetOp(workspace, [c0, c16], sizes=[16, 16]).result diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 7ca49381..f6d56286 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -340,6 +340,11 @@ process_one_dir() { overall=1 continue fi + if ! grep -Fq "pto.multi_buffer = 2 : i32" "$pto_input"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing pto.multi_buffer=2 annotation for subset ping/pong" + overall=1 + continue + fi local tassign_count tassign_count="$(grep -c "TASSIGN(" "$cpp" || true)" if [[ -z "${tassign_count}" || "${tassign_count}" -lt 3 ]]; then