diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 09441acb..7f1ac733 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -1166,6 +1166,45 @@ def WaitFlagOp : PTO_Op<"wait_flag"> { }]; } +def SetFlagDynOp : PTO_Op<"set_flag_dyn"> { + let summary = "Set synchronization flag between pipes (dynamic event id)"; + let description = [{ + `pto.set_flag_dyn` is the dynamic form of `pto.set_flag`. It is intended + for multi-buffer synchronization patterns where the event id is selected + at runtime (e.g. ping-pong between two event ids in a loop). + + The event id value must be in the valid backend range (typically 0..7). + }]; + + let arguments = (ins + PTO_PipeAttr:$src_pipe, + PTO_PipeAttr:$dst_pipe, + Index:$event_id + ); + let results = (outs); + let assemblyFormat = [{ + `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + }]; +} + +def WaitFlagDynOp : PTO_Op<"wait_flag_dyn"> { + let summary = "Wait for synchronization flag (dynamic event id)"; + let description = [{ + `pto.wait_flag_dyn` is the dynamic form of `pto.wait_flag`. See + `pto.set_flag_dyn` for the intended use cases. + }]; + + let arguments = (ins + PTO_PipeAttr:$src_pipe, + PTO_PipeAttr:$dst_pipe, + Index:$event_id + ); + let results = (outs); + let assemblyFormat = [{ + `[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Buffer-ID Synchronization (A5) //===----------------------------------------------------------------------===// diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 5249d4b0..de80b298 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -141,6 +141,9 @@ void PTOIRTranslator::RecursionIR(Region *region) { else if (auto subViewOp = dyn_cast(op)) { UpdateAliasBufferInfo(subViewOp.getResult(), subViewOp.getSource()); } + else if (auto subsetOp = dyn_cast(op)) { + UpdateAliasBufferInfo(subsetOp.getResult(), subsetOp.getSource()); + } else if (auto memrefSubView = dyn_cast(op)) { UpdateAliasBufferInfo(memrefSubView.getResult(), memrefSubView.getSource()); } diff --git a/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp b/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp index dd30bf8c..98973236 100644 --- a/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp @@ -57,7 +57,8 @@ struct PTOInsertSyncPass : public mlir::pto::impl::PTOInsertSyncBase(op)) { hasExplicitSync = true; return WalkResult::interrupt(); diff --git a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp index 4e8c5c28..4c942276 100644 --- a/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp +++ b/lib/PTO/Transforms/InsertSync/SyncCodegen.cpp @@ -279,10 +279,14 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, Operation *op, SyncOperation *sync, bool beforeInsert) { - // 注意:GetBufferSelected 可能需要在插入 Set/Wait 之前调用,以确保 SSA 顺序 - // 但这里只是获取 Value,不影响 InsertionPoint 的设定 + // Multi-buffer sync: select event id at runtime (e.g. ping-pong in loop). Value bufferSelected = GetBufferSelected(rewriter, op, sync); - (void)bufferSelected; + if (!bufferSelected) { + // Fallback to a fixed event id to avoid crashing on malformed IR. + // This should not happen for well-formed multi-buffer sync operations. + bufferSelected = + rewriter.create(op->getLoc(), sync->eventIds[0]); + } // [Fix] Terminator 强制前置插入 if (beforeInsert || op->hasTrait()) { @@ -293,19 +297,13 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter, auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe()); auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe()); - auto eventId = getEventAttr(rewriter, sync->eventIds[0]); // 注意:MultiBuffer可能需要特殊处理Attr - - // 这里假设 SetFlagOp/WaitFlagOp 支持动态 Value 作为 EventID,或者您有特殊的 Op - // 如果 PTO 定义只支持 Attribute,那么上面的 GetBufferSelected 逻辑需要配合修改 Op 定义 - // 假设目前的 Op 定义如下: + if (sync->isSyncWaitType()) { - // 假设 WaitFlagOp 有支持 Value eventId 的重载或变体 - // 如果没有,这行代码可能需要调整。但在您之前的 Double Buffer 测试中,看起来它是工作的? - // 或者您是否使用了 UpdateFlagOp (带 Value)? - // 这里保持原样,只修改 InsertionPoint - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); + rewriter.create(op->getLoc(), srcPipe, dstPipe, + bufferSelected); } else { - rewriter.create(op->getLoc(), srcPipe, dstPipe, eventId); + rewriter.create(op->getLoc(), srcPipe, dstPipe, + bufferSelected); } } diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 1275895c..56e0d365 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -3568,6 +3568,58 @@ struct PTOWaitFlagToEmitC : public OpConversionPattern { } }; +struct PTOSetFlagDynToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFlagDynOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + + const std::string srcTok = pipeTokFromPipeAttr(op.getSrcPipeAttr()); + const std::string dstTok = pipeTokFromPipeAttr(op.getDstPipeAttr()); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + // The index-typed integer refers to operand #0 (event id). + rewriter.getIndexAttr(0), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{adaptor.getEventId()}); + return success(); + } +}; + +struct PTOWaitFlagDynToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::WaitFlagDynOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + + const std::string srcTok = pipeTokFromPipeAttr(op.getSrcPipeAttr()); + const std::string dstTok = pipeTokFromPipeAttr(op.getDstPipeAttr()); + auto argsAttr = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, srcTok), + emitc::OpaqueAttr::get(ctx, dstTok), + // The index-typed integer refers to operand #0 (event id). + rewriter.getIndexAttr(0), + }); + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "wait_flag", + /*args=*/argsAttr, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{adaptor.getEventId()}); + return success(); + } +}; + struct PTOGetBufToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -6923,9 +6975,11 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/test/samples/InjectSync/test_inject_sync_multibuf_loop.py b/test/samples/InjectSync/test_inject_sync_multibuf_loop.py new file mode 100644 index 00000000..1f154c6d --- /dev/null +++ b/test/samples/InjectSync/test_inject_sync_multibuf_loop.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +from mlir.ir import ( + Context, + Location, + Module, + InsertionPoint, + F32Type, + IndexType, +) +from mlir.dialects import func, arith, scf, pto +from mlir.dialects.arith import CmpIPredicate + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + + with Location.unknown(ctx): + m = Module.create() + f32 = F32Type.get(ctx) + idx = IndexType.get(ctx) + + # A minimal ping-pong (double-buffer) loop. + # + # This triggers multi-buffer synchronization insertion on the loop + # back-edge (MTE3 -> MTE2) when `--enable-insert-sync` is enabled: + # the inserted sync needs 2 event IDs (ping-pong) because the + # dependency touches local memory (VEC/UB), so the event id is + # selected by `iv % 2`. + vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx) + bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx) + sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx) + pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx) + + fractal_ab_size = pto.TileConfig.fractalABSize + cfg = pto.TileBufConfigAttr.get(bl, sl, fractal_ab_size, pd, ctx) + + ptr_f32 = pto.PtrType.get(f32, ctx) + tv2_f32 = pto.TensorViewType.get(2, f32, ctx) + tile_view_32 = pto.PartitionTensorViewType.get([32, 32], f32, ctx) + ws_type = pto.TileBufType.get([32, 64], f32, vec, [32, 32], cfg, ctx) + + fn_ty = func.FunctionType.get([ptr_f32, ptr_f32], []) + with InsertionPoint(m.body): + fn = func.FuncOp("test_inject_sync_multibuf_loop", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + c0 = arith.ConstantOp(idx, 0).result + c1 = arith.ConstantOp(idx, 1).result + c2 = arith.ConstantOp(idx, 2).result + c4 = arith.ConstantOp(idx, 4).result + c32 = arith.ConstantOp(idx, 32).result + + src_ptr, dst_ptr = entry.arguments + + tv_src = pto.MakeTensorViewOp(tv2_f32, src_ptr, [c32, c32], [c32, c1]).result + tv_dst = pto.MakeTensorViewOp(tv2_f32, dst_ptr, [c32, c32], [c32, c1]).result + sv_src = pto.PartitionViewOp(tile_view_32, tv_src, offsets=[c0, c0], sizes=[c32, c32]).result + sv_dst = pto.PartitionViewOp(tile_view_32, tv_dst, offsets=[c0, c0], sizes=[c32, c32]).result + + # Allocate a single workspace and create two non-overlapping + # 32x32 subsets (ping/pong) to model double buffering. + workspace = pto.AllocTileOp(ws_type).result + ping = pto.SubsetOp(workspace, [c0, c0], sizes=[32, 32]).result + pong = pto.SubsetOp(workspace, [c0, c32], sizes=[32, 32]).result + + loop = scf.ForOp(c0, c4, c1, []) + with InsertionPoint(loop.body): + iv = loop.induction_variable + + parity = arith.RemUIOp(iv, c2).result + is_even = arith.CmpIOp(CmpIPredicate.eq, parity, c0).result + ifop = scf.IfOp(is_even, [], hasElse=True) + with InsertionPoint(ifop.then_block): + pto.TLoadOp(None, sv_src, ping) + pto.TStoreOp(None, ping, sv_dst) + scf.YieldOp([]) + with InsertionPoint(ifop.else_block): + pto.TLoadOp(None, sv_src, pong) + pto.TStoreOp(None, pong, sv_dst) + scf.YieldOp([]) + + scf.YieldOp([]) + + func.ReturnOp([]) + + m.operation.verify() + return m + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index 80098379..8202f40f 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -268,6 +268,17 @@ process_one_dir() { fi fi + # Regression guard: multi-buffer (ping-pong) sync must use the selected + # event id in the generated C++ (dynamic event id), not a fixed EVENT_ID0. + if [[ "$base" == "test_inject_sync_multibuf_loop" ]]; then + if ! grep -Eq "wait_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp" || \ + ! grep -Eq "set_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing dynamic event id for PIPE_MTE3->PIPE_MTE2 multi-buffer sync" + overall=1 + continue + fi + fi + # Regression guard for issue #117: vector mask must be reset for each # `pto.section.vector` region to avoid cross-kernel state leakage. # Use an existing sample (Complex/cv_region.py) that contains a vector section. @@ -368,6 +379,17 @@ process_one_dir() { fi fi + # Regression guard: multi-buffer (ping-pong) sync must use the selected + # event id in the generated C++ (dynamic event id), not a fixed EVENT_ID0. + if [[ "$base" == "test_inject_sync_loop" ]]; then + if ! grep -Eq "wait_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp" || \ + ! grep -Eq "set_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp"; then + echo -e "${A}(${base}.pto)\tFAIL\tmissing dynamic event id for PIPE_MTE3->PIPE_MTE2 multi-buffer sync" + overall=1 + continue + fi + fi + echo -e "${A}(${base}.pto)\tOK\tgenerated: $(basename "$cpp")" done fi