Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,45 @@ def WaitFlagOp : PTO_Op<"wait_flag"> {
}];
}

// ODS record for `pto.set_flag_dyn`, the dynamic-event-id form of
// `pto.set_flag`. The source and destination pipes are attributes, while the
// event id is an index-typed SSA operand. The op has no results.
// NOTE(review): the description states a valid event-id range, but this record
// declares no verifier for it -- confirm whether the range is checked elsewhere.
def SetFlagDynOp : PTO_Op<"set_flag_dyn"> {
let summary = "Set synchronization flag between pipes (dynamic event id)";
let description = [{
`pto.set_flag_dyn` is the dynamic form of `pto.set_flag`. It is intended
for multi-buffer synchronization patterns where the event id is selected
at runtime (e.g. ping-pong between two event ids in a loop).

The event id value must be in the valid backend range (typically 0..7).
}];

// Inputs: two pipe attributes followed by the index-typed event id operand.
let arguments = (ins
PTO_PipeAttr:$src_pipe,
PTO_PipeAttr:$dst_pipe,
Index:$event_id
);
// The op produces no results.
let results = (outs);
// Textual form: `[` src_pipe `,` dst_pipe `,` event_id `]` then the attr-dict.
let assemblyFormat = [{
`[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict
}];
}

// ODS record for `pto.wait_flag_dyn`, the dynamic-event-id form of
// `pto.wait_flag`. The source and destination pipes are attributes, while the
// event id is an index-typed SSA operand. The op has no results.
def WaitFlagDynOp : PTO_Op<"wait_flag_dyn"> {
let summary = "Wait for synchronization flag (dynamic event id)";
let description = [{
`pto.wait_flag_dyn` is the dynamic form of `pto.wait_flag`. See
`pto.set_flag_dyn` for the intended use cases.
}];

// Inputs: two pipe attributes followed by the index-typed event id operand.
let arguments = (ins
PTO_PipeAttr:$src_pipe,
PTO_PipeAttr:$dst_pipe,
Index:$event_id
);
// The op produces no results.
let results = (outs);
// Textual form: `[` src_pipe `,` dst_pipe `,` event_id `]` then the attr-dict.
let assemblyFormat = [{
`[` $src_pipe `,` $dst_pipe `,` $event_id `]` attr-dict
}];
}

//===----------------------------------------------------------------------===//
// Buffer-ID Synchronization (A5)
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ void PTOIRTranslator::RecursionIR(Region *region) {
else if (auto subViewOp = dyn_cast<pto::PartitionViewOp>(op)) {
UpdateAliasBufferInfo(subViewOp.getResult(), subViewOp.getSource());
}
else if (auto subsetOp = dyn_cast<pto::SubsetOp>(op)) {
UpdateAliasBufferInfo(subsetOp.getResult(), subsetOp.getSource());
}
else if (auto memrefSubView = dyn_cast<memref::SubViewOp>(op)) {
UpdateAliasBufferInfo(memrefSubView.getResult(), memrefSubView.getSource());
}
Expand Down
3 changes: 2 additions & 1 deletion lib/PTO/Transforms/InsertSync/PTOInsertSync.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ struct PTOInsertSyncPass : public mlir::pto::impl::PTOInsertSyncBase<PTOInsertSy
//
bool hasExplicitSync = false;
func.walk([&](Operation *op) {
if (isa<pto::SetFlagOp, pto::WaitFlagOp, pto::RecordEventOp,
if (isa<pto::SetFlagOp, pto::WaitFlagOp, pto::SetFlagDynOp,
pto::WaitFlagDynOp, pto::RecordEventOp,
pto::WaitEventOp>(op)) {
hasExplicitSync = true;
return WalkResult::interrupt();
Expand Down
26 changes: 12 additions & 14 deletions lib/PTO/Transforms/InsertSync/SyncCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,14 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter,
Operation *op,
SyncOperation *sync,
bool beforeInsert) {
// 注意:GetBufferSelected 可能需要在插入 Set/Wait 之前调用,以确保 SSA 顺序
// 但这里只是获取 Value,不影响 InsertionPoint 的设定
// Multi-buffer sync: select event id at runtime (e.g. ping-pong in loop).
Value bufferSelected = GetBufferSelected(rewriter, op, sync);
(void)bufferSelected;
if (!bufferSelected) {
// Fallback to a fixed event id to avoid crashing on malformed IR.
// This should not happen for well-formed multi-buffer sync operations.
bufferSelected =
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), sync->eventIds[0]);
}

// [Fix] Terminator 强制前置插入
if (beforeInsert || op->hasTrait<OpTrait::IsTerminator>()) {
Expand All @@ -293,19 +297,13 @@ void SyncCodegen::CreateSetWaitOpForMultiBuffer(IRRewriter &rewriter,

auto srcPipe = getPipeAttr(rewriter, sync->GetActualSrcPipe());
auto dstPipe = getPipeAttr(rewriter, sync->GetActualDstPipe());
auto eventId = getEventAttr(rewriter, sync->eventIds[0]); // 注意:MultiBuffer可能需要特殊处理Attr

// 这里假设 SetFlagOp/WaitFlagOp 支持动态 Value 作为 EventID,或者您有特殊的 Op
// 如果 PTO 定义只支持 Attribute,那么上面的 GetBufferSelected 逻辑需要配合修改 Op 定义
// 假设目前的 Op 定义如下:

if (sync->isSyncWaitType()) {
// 假设 WaitFlagOp 有支持 Value eventId 的重载或变体
// 如果没有,这行代码可能需要调整。但在您之前的 Double Buffer 测试中,看起来它是工作的?
// 或者您是否使用了 UpdateFlagOp (带 Value)?
// 这里保持原样,只修改 InsertionPoint
rewriter.create<pto::WaitFlagOp>(op->getLoc(), srcPipe, dstPipe, eventId);
rewriter.create<pto::WaitFlagDynOp>(op->getLoc(), srcPipe, dstPipe,
bufferSelected);
} else {
rewriter.create<pto::SetFlagOp>(op->getLoc(), srcPipe, dstPipe, eventId);
rewriter.create<pto::SetFlagDynOp>(op->getLoc(), srcPipe, dstPipe,
bufferSelected);
}
}

Expand Down
54 changes: 54 additions & 0 deletions lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3568,6 +3568,58 @@ struct PTOWaitFlagToEmitC : public OpConversionPattern<mlir::pto::WaitFlagOp> {
}
};

/// Conversion pattern that rewrites `pto.set_flag_dyn` into an opaque EmitC
/// call named `set_flag`.
///
/// The call's `args` array carries the source and destination pipe tokens as
/// opaque attributes, followed by an index attribute `0` that refers to call
/// operand #0: the (converted) event-id value.
struct PTOSetFlagDynToEmitC
    : public OpConversionPattern<mlir::pto::SetFlagDynOp> {
  using OpConversionPattern<mlir::pto::SetFlagDynOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::pto::SetFlagDynOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *context = rewriter.getContext();

    // Pipe tokens derived from the op's two pipe attributes.
    const std::string srcPipeToken = pipeTokFromPipeAttr(op.getSrcPipeAttr());
    const std::string dstPipeToken = pipeTokFromPipeAttr(op.getDstPipeAttr());

    auto srcPipeArg = emitc::OpaqueAttr::get(context, srcPipeToken);
    auto dstPipeArg = emitc::OpaqueAttr::get(context, dstPipeToken);
    // The index-typed integer refers to operand #0 (event id).
    auto eventIdArg = rewriter.getIndexAttr(0);
    auto callArgs = rewriter.getArrayAttr({srcPipeArg, dstPipeArg, eventIdArg});

    // Keep the converted event id in a named local so the ValueRange built
    // below refers to a live value.
    auto eventId = adaptor.getEventId();

    rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
        op, TypeRange{}, "set_flag",
        /*args=*/callArgs,
        /*templateArgs=*/ArrayAttr{},
        /*operands=*/ValueRange{eventId});
    return success();
  }
};

/// Conversion pattern that rewrites `pto.wait_flag_dyn` into an opaque EmitC
/// call named `wait_flag`.
///
/// The call's `args` array carries the source and destination pipe tokens as
/// opaque attributes, followed by an index attribute `0` that refers to call
/// operand #0: the (converted) event-id value.
struct PTOWaitFlagDynToEmitC
    : public OpConversionPattern<mlir::pto::WaitFlagDynOp> {
  using OpConversionPattern<mlir::pto::WaitFlagDynOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::pto::WaitFlagDynOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *context = rewriter.getContext();

    // Pipe tokens derived from the op's two pipe attributes.
    const std::string srcPipeToken = pipeTokFromPipeAttr(op.getSrcPipeAttr());
    const std::string dstPipeToken = pipeTokFromPipeAttr(op.getDstPipeAttr());

    auto srcPipeArg = emitc::OpaqueAttr::get(context, srcPipeToken);
    auto dstPipeArg = emitc::OpaqueAttr::get(context, dstPipeToken);
    // The index-typed integer refers to operand #0 (event id).
    auto eventIdArg = rewriter.getIndexAttr(0);
    auto callArgs = rewriter.getArrayAttr({srcPipeArg, dstPipeArg, eventIdArg});

    // Keep the converted event id in a named local so the ValueRange built
    // below refers to a live value.
    auto eventId = adaptor.getEventId();

    rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
        op, TypeRange{}, "wait_flag",
        /*args=*/callArgs,
        /*templateArgs=*/ArrayAttr{},
        /*operands=*/ValueRange{eventId});
    return success();
  }
};

struct PTOGetBufToEmitC : public OpConversionPattern<mlir::pto::GetBufOp> {
using OpConversionPattern<mlir::pto::GetBufOp>::OpConversionPattern;

Expand Down Expand Up @@ -6923,9 +6975,11 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns,
patterns.add<ArithCmpIToEmitC>(typeConverter, ctx);
patterns.add<PTOBindTileToEmitC>(typeConverter, ctx);
patterns.add<PTOSetFlagToEmitC>(typeConverter, ctx);
patterns.add<PTOSetFlagDynToEmitC>(typeConverter, ctx);
patterns.add<PTOSubSCToEmitC>(typeConverter, ctx);
patterns.add<PTOSubCSToEmitC>(typeConverter, ctx);
patterns.add<PTOWaitFlagToEmitC>(typeConverter, ctx);
patterns.add<PTOWaitFlagDynToEmitC>(typeConverter, ctx);
patterns.add<PTOGetBufToEmitC>(typeConverter, ctx);
patterns.add<PTORlsBufToEmitC>(typeConverter, ctx);
patterns.add<PTOXORSToEmitC>(typeConverter, ctx);
Expand Down
93 changes: 93 additions & 0 deletions test/samples/InjectSync/test_inject_sync_multibuf_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#!/usr/bin/env python3
from mlir.ir import (
Context,
Location,
Module,
InsertionPoint,
F32Type,
IndexType,
)
from mlir.dialects import func, arith, scf, pto
from mlir.dialects.arith import CmpIPredicate


def build():
    """Build and return the `test_inject_sync_multibuf_loop` sample module.

    The module holds one function taking two f32 pointers. Its body runs an
    `scf.for` loop from 0 to 4 with step 1; each iteration loads from the
    source view into one of two 32x32 subsets of a single tile buffer and
    stores that subset to the destination view, choosing the subset by
    whether `iv % 2 == 0`. The module is verified before being returned.
    """
    with Context() as context:
        pto.register_dialect(context, load=True)

        with Location.unknown(context):
            module = Module.create()
            elem_ty = F32Type.get(context)
            index_ty = IndexType.get(context)

            # A minimal ping-pong (double-buffer) loop.
            #
            # This triggers multi-buffer synchronization insertion on the loop
            # back-edge (MTE3 -> MTE2) when `--enable-insert-sync` is enabled:
            # the inserted sync needs 2 event IDs (ping-pong) because the
            # dependency touches local memory (VEC/UB), so the event id is
            # selected by `iv % 2`.
            vec_space = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, context)
            b_layout = pto.BLayoutAttr.get(pto.BLayout.RowMajor, context)
            s_layout = pto.SLayoutAttr.get(pto.SLayout.NoneBox, context)
            pad_value = pto.PadValueAttr.get(pto.PadValue.Null, context)

            fractal_size = pto.TileConfig.fractalABSize
            tile_cfg = pto.TileBufConfigAttr.get(
                b_layout, s_layout, fractal_size, pad_value, context
            )

            ptr_ty = pto.PtrType.get(elem_ty, context)
            tensor_view_ty = pto.TensorViewType.get(2, elem_ty, context)
            part_view_ty = pto.PartitionTensorViewType.get([32, 32], elem_ty, context)
            workspace_ty = pto.TileBufType.get(
                [32, 64], elem_ty, vec_space, [32, 32], tile_cfg, context
            )

            func_ty = func.FunctionType.get([ptr_ty, ptr_ty], [])
            with InsertionPoint(module.body):
                kernel = func.FuncOp("test_inject_sync_multibuf_loop", func_ty)
                entry_block = kernel.add_entry_block()

            with InsertionPoint(entry_block):
                # Index constants, created in the same order as the ops
                # should appear in the emitted IR.
                zero = arith.ConstantOp(index_ty, 0).result
                one = arith.ConstantOp(index_ty, 1).result
                two = arith.ConstantOp(index_ty, 2).result
                four = arith.ConstantOp(index_ty, 4).result
                thirty_two = arith.ConstantOp(index_ty, 32).result

                src_ptr, dst_ptr = entry_block.arguments

                src_tensor = pto.MakeTensorViewOp(
                    tensor_view_ty, src_ptr, [thirty_two, thirty_two], [thirty_two, one]
                ).result
                dst_tensor = pto.MakeTensorViewOp(
                    tensor_view_ty, dst_ptr, [thirty_two, thirty_two], [thirty_two, one]
                ).result
                src_view = pto.PartitionViewOp(
                    part_view_ty,
                    src_tensor,
                    offsets=[zero, zero],
                    sizes=[thirty_two, thirty_two],
                ).result
                dst_view = pto.PartitionViewOp(
                    part_view_ty,
                    dst_tensor,
                    offsets=[zero, zero],
                    sizes=[thirty_two, thirty_two],
                ).result

                # Allocate a single workspace and create two non-overlapping
                # 32x32 subsets (ping/pong) to model double buffering.
                workspace = pto.AllocTileOp(workspace_ty).result
                ping = pto.SubsetOp(workspace, [zero, zero], sizes=[32, 32]).result
                pong = pto.SubsetOp(workspace, [zero, thirty_two], sizes=[32, 32]).result

                loop = scf.ForOp(zero, four, one, [])
                with InsertionPoint(loop.body):
                    iv = loop.induction_variable

                    parity = arith.RemUIOp(iv, two).result
                    is_even = arith.CmpIOp(CmpIPredicate.eq, parity, zero).result
                    ifop = scf.IfOp(is_even, [], hasElse=True)

                    # Even iterations use `ping` (then-branch), odd ones use
                    # `pong` (else-branch); both branches emit the same
                    # load/store/yield sequence.
                    for branch_block, tile in (
                        (ifop.then_block, ping),
                        (ifop.else_block, pong),
                    ):
                        with InsertionPoint(branch_block):
                            pto.TLoadOp(None, src_view, tile)
                            pto.TStoreOp(None, tile, dst_view)
                            scf.YieldOp([])

                    scf.YieldOp([])

                func.ReturnOp([])

            module.operation.verify()
            return module


# When run as a script, build the sample module and print it to stdout.
if __name__ == "__main__":
    built_module = build()
    print(built_module)
22 changes: 22 additions & 0 deletions test/samples/runop.sh
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,17 @@ process_one_dir() {
fi
fi

# Regression guard: multi-buffer (ping-pong) sync must use the selected
# event id in the generated C++ (dynamic event id), not a fixed EVENT_ID0.
if [[ "$base" == "test_inject_sync_multibuf_loop" ]]; then
if ! grep -Eq "wait_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp" || \
! grep -Eq "set_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp"; then
echo -e "${A}(${base}.py)\tFAIL\tmissing dynamic event id for PIPE_MTE3->PIPE_MTE2 multi-buffer sync"
overall=1
continue
fi
fi

# Regression guard for issue #117: vector mask must be reset for each
# `pto.section.vector` region to avoid cross-kernel state leakage.
# Use an existing sample (Complex/cv_region.py) that contains a vector section.
Expand Down Expand Up @@ -368,6 +379,17 @@ process_one_dir() {
fi
fi

# Regression guard: multi-buffer (ping-pong) sync must use the selected
# event id in the generated C++ (dynamic event id), not a fixed EVENT_ID0.
if [[ "$base" == "test_inject_sync_loop" ]]; then
if ! grep -Eq "wait_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp" || \
! grep -Eq "set_flag\\(PIPE_MTE3,[[:space:]]*PIPE_MTE2,[[:space:]]*(\\([[:alnum:]_]+\\)[[:space:]]*)?v[0-9]+\\)" "$cpp"; then
echo -e "${A}(${base}.pto)\tFAIL\tmissing dynamic event id for PIPE_MTE3->PIPE_MTE2 multi-buffer sync"
overall=1
continue
fi
fi

echo -e "${A}(${base}.pto)\tOK\tgenerated: $(basename "$cpp")"
done
fi
Expand Down
Loading