From 93a48cbaebf487877000c9f5fa55245c18b0c7e7 Mon Sep 17 00:00:00 2001 From: TaoTao-real Date: Fri, 13 Mar 2026 12:10:18 +0800 Subject: [PATCH] [Sync] Align inter-core sync lowering for A3/A5 and add regressions --- lib/PTO/Transforms/PTOToEmitC.cpp | 181 ++++++++++++++++-- python/pto/dialects/pto.py | 37 ++++ test/samples/Sync/test_intercore_sync_a3.py | 34 ++++ .../test_intercore_sync_a3_missing_setffts.py | 29 +++ test/samples/Sync/test_intercore_sync_a5.py | 29 +++ test/samples/runop.sh | 74 ++++++- 6 files changed, 363 insertions(+), 21 deletions(-) create mode 100644 test/samples/Sync/test_intercore_sync_a3.py create mode 100644 test/samples/Sync/test_intercore_sync_a3_missing_setffts.py create mode 100644 test/samples/Sync/test_intercore_sync_a5.py diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index 939287dd..950c24cd 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -303,6 +303,7 @@ static constexpr unsigned kPTOIndexBitWidth = 32; // keep consistent with IndexType conversion // Forward declarations (definitions below). +static inline std::string pipeTokFromPipeAttr(mlir::pto::PipeAttr a); static emitc::OpaqueType getSignedIntOpaqueType(MLIRContext *ctx, unsigned bitWidth); static emitc::OpaqueType getUnsignedIntOpaqueType(MLIRContext *ctx, @@ -322,6 +323,104 @@ static Value castSignlessIntToUnsignedSameWidth(ConversionPatternRewriter &rewri Location loc, Value v, unsigned bitWidth); +static bool isSetFFTsPointerLikeType(Type ty) { + if (isa(ty)) + return true; + if (auto opaqueTy = dyn_cast(ty)) + return opaqueTy.getValue().ends_with("*"); + return false; +} + +struct InterCoreSyncCallDesc { + const char *callee = nullptr; + ArrayAttr args; + SmallVector operands; +}; + +static InterCoreSyncCallDesc buildInterCoreSyncSetCall( + ConversionPatternRewriter &rewriter, Location loc, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + if (targetArch == PTOArch::A3) { + auto i32Ty = emitc::OpaqueType::get(ctx, "int32_t"); + Value eventVal = + makeEmitCIntConstant(rewriter, loc, i32Ty, eventIdAttr.getInt()); + + auto msgTy = emitc::OpaqueType::get(ctx, "uint16_t"); + auto msgArgs = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, "FFTS_MODE_VAL"), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + Value msgVal = + rewriter + .create(loc, msgTy, "getFFTSMsg", + /*args=*/msgArgs, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{eventVal}) + .getResult(0); + + InterCoreSyncCallDesc desc; + desc.callee = "ffts_cross_core_sync"; + desc.args = rewriter.getArrayAttr({ + emitc::OpaqueAttr::get(ctx, pipeTok), + IntegerAttr::get(IndexType::get(ctx), 0), + }); + desc.operands.push_back(msgVal); + return desc; + } + + InterCoreSyncCallDesc desc; + desc.callee = "set_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static InterCoreSyncCallDesc buildInterCoreSyncWaitCall( + ConversionPatternRewriter &rewriter, PTOArch targetArch, + pto::PipeAttr pipeAttr, IntegerAttr eventIdAttr) { + auto *ctx = rewriter.getContext(); + std::string pipeTok = pipeTokFromPipeAttr(pipeAttr); + + InterCoreSyncCallDesc desc; + if (targetArch == PTOArch::A3) { + desc.callee = "wait_flag_dev"; + desc.args = rewriter.getArrayAttr({eventIdAttr}); + return desc; + } + + desc.callee = "wait_intra_block"; + desc.args = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, pipeTok), eventIdAttr}); + return desc; +} + +static bool hasInterCoreSyncOp(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + +static bool hasSetFFTsOp(func::FuncOp func) { + bool found = false; + func.walk([&](Operation *op) { + if (isa(op)) { + found = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return found; +} + //===----------------------------------------------------------------------===// // Arith -> EmitC (full dialect coverage for scalar ops) //===----------------------------------------------------------------------===// @@ -3686,6 +3785,41 @@ struct PTORlsBufToEmitC : public OpConversionPattern { } }; +struct PTOSetFFTsToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(mlir::pto::SetFFTsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = rewriter.getContext(); + auto loc = op.getLoc(); + + Value fftsAddr = peelUnrealized(adaptor.getFfts()); + auto u64Ty = emitc::OpaqueType::get(ctx, "uint64_t"); + + if (isSetFFTsPointerLikeType(fftsAddr.getType())) { + auto castTyAttr = + rewriter.getArrayAttr({emitc::OpaqueAttr::get(ctx, "uint64_t")}); + fftsAddr = + rewriter + .create(loc, u64Ty, "reinterpret_cast", + /*args=*/ArrayAttr{}, + /*templateArgs=*/castTyAttr, + /*operands=*/ValueRange{fftsAddr}) + .getResult(0); + } else if (fftsAddr.getType() != u64Ty) { + fftsAddr = + rewriter.create(loc, u64Ty, fftsAddr).getResult(); + } + + rewriter.replaceOpWithNewOp( + op, TypeRange{}, "set_ffts_base_addr", + /*args=*/ArrayAttr{}, + /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{fftsAddr}); + return success(); + } +}; + struct PTOSyncSetToEmitC : public OpConversionPattern { PTOSyncSetToEmitC(TypeConverter &typeConverter, MLIRContext *ctx, PTOArch targetArch) @@ -3696,19 +3830,13 @@ struct PTOSyncSetToEmitC : public OpConversionPattern { matchAndRewrite(mlir::pto::SyncSetOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { (void)adaptor; - auto *ctx = rewriter.getContext(); auto loc = op->getLoc(); - - std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), op.getEventIdAttr()}); - const char *kSyncSetCallee = (targetArch == PTOArch::A3) - ? "ffts_cross_core_sync" - : "set_intra_block"; - rewriter.create(loc, TypeRange{}, kSyncSetCallee, - /*args=*/argsAttr, + auto desc = buildInterCoreSyncSetCall(rewriter, loc, targetArch, op.getPipe(), + op.getEventIdAttr()); + rewriter.create(loc, TypeRange{}, desc.callee, + /*args=*/desc.args, /*templateArgs=*/ArrayAttr{}, - /*operands=*/ValueRange{}); + /*operands=*/desc.operands); rewriter.eraseOp(op); return success(); @@ -3727,16 +3855,11 @@ struct PTOSyncWaitToEmitC : public OpConversionPattern { matchAndRewrite(mlir::pto::SyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { (void)adaptor; - auto *ctx = rewriter.getContext(); auto loc = op->getLoc(); - - std::string pipeTok = pipeTokFromPipeAttr(op.getPipe()); - auto argsAttr = rewriter.getArrayAttr( - {emitc::OpaqueAttr::get(ctx, pipeTok), op.getEventIdAttr()}); - const char *kSyncWaitCallee = - (targetArch == PTOArch::A3) ? "wait_flag_dev" : "wait_intra_block"; - rewriter.create(loc, TypeRange{}, kSyncWaitCallee, - argsAttr, ArrayAttr{}, ValueRange{}); + auto desc = buildInterCoreSyncWaitCall(rewriter, targetArch, op.getPipe(), + op.getEventIdAttr()); + rewriter.create(loc, TypeRange{}, desc.callee, + desc.args, ArrayAttr{}, desc.operands); rewriter.eraseOp(op); return success(); @@ -7258,6 +7381,7 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -7456,6 +7580,23 @@ struct EmitPTOManualPass MLIRContext *ctx = &getContext(); ModuleOp mop = getOperation(); + // A3 requires explicit FFTS base setup for inter-core sync ops. + if (targetArch == PTOArch::A3) { + bool hasMissingSetFFTs = false; + for (auto func : mop.getOps()) { + if (!hasInterCoreSyncOp(func)) + continue; + if (hasSetFFTsOp(func)) + continue; + hasMissingSetFFTs = true; + func.emitError() + << "A3 inter-core sync requires explicit `pto.set_ffts` in the " + "same function when using `pto.sync.set`/`pto.sync.wait`"; + } + if (hasMissingSetFFTs) + return signalPassFailure(); + } + // 1. 插入头文件 auto loc = mop->getLoc(); OpBuilder builder(ctx); diff --git a/python/pto/dialects/pto.py b/python/pto/dialects/pto.py index eb51b2f2..78183c1c 100644 --- a/python/pto/dialects/pto.py +++ b/python/pto/dialects/pto.py @@ -79,6 +79,8 @@ def _load_local_pto_ext(): "TileConfig", # High-level sync helpers "record_event", "wait_event", "barrier", + # Inter-core sync helpers + "sync_set", "sync_wait", "set_ffts", # A5 buffer-id sync helpers "get_buf", "rls_buf", # Scalar pointer helpers @@ -168,6 +170,41 @@ def barrier(op, *, loc=None, ip=None): # Otherwise fall back to low-level barrier expecting PipeAttr return _pto_ops_gen.barrier(op, loc=loc, ip=ip) +# ----------------------------------------------------------------------------- +# Inter-core sync helpers (pto.sync.set / pto.sync.wait / pto.set_ffts) +# ----------------------------------------------------------------------------- +def sync_set(pipe, event_id, *, loc=None, ip=None): + ctx = loc.context if loc else _ods_ir.Context.current + return _ods_ir.Operation.create( + "pto.sync.set", + attributes={ + "pipe": _ensure_pipe_attr(pipe, ctx), + "event_id": _ensure_i32_attr(event_id, "event_id", ctx), + }, + loc=loc, + ip=ip, + ) + +def sync_wait(pipe, event_id, *, loc=None, ip=None): + ctx = loc.context if loc else _ods_ir.Context.current + return _ods_ir.Operation.create( + "pto.sync.wait", + attributes={ + "pipe": _ensure_pipe_attr(pipe, ctx), + "event_id": _ensure_i32_attr(event_id, "event_id", ctx), + }, + loc=loc, + ip=ip, + ) + +def set_ffts(ffts, *, loc=None, ip=None): + return _ods_ir.Operation.create( + "pto.set_ffts", + operands=[_pto_ops_gen._get_op_result_or_value(ffts)], + loc=loc, + ip=ip, + ) + # ----------------------------------------------------------------------------- # A5 buffer-id sync helpers # ----------------------------------------------------------------------------- diff --git a/test/samples/Sync/test_intercore_sync_a3.py b/test/samples/Sync/test_intercore_sync_a3.py new file mode 100644 index 00000000..04eaf2ee --- /dev/null +++ b/test/samples/Sync/test_intercore_sync_a3.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +from mlir.ir import Context, InsertionPoint, IntegerType, Location, MemRefType, Module +from mlir.dialects import func, pto + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + with Location.unknown(ctx): + module = Module.create() + + i64 = IntegerType.get_signless(64, ctx) + # Minimal valid memref operand for pto.set_ffts verifier (i64 element). + ffts_ty = MemRefType.get([1], i64) + fn_ty = func.FunctionType.get([ffts_ty], []) + + with InsertionPoint(module.body): + fn = func.FuncOp("test_intercore_sync_a3", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + pipe_fix = pto.PipeAttr.get(pto.PIPE.PIPE_FIX, ctx) + pipe_v = pto.PipeAttr.get(pto.PIPE.PIPE_V, ctx) + pto.set_ffts(entry.arguments[0]) + pto.sync_set(pipe_fix, 3) + pto.sync_wait(pipe_v, 3) + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Sync/test_intercore_sync_a3_missing_setffts.py b/test/samples/Sync/test_intercore_sync_a3_missing_setffts.py new file mode 100644 index 00000000..e33ea0ac --- /dev/null +++ b/test/samples/Sync/test_intercore_sync_a3_missing_setffts.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +from mlir.ir import Context, InsertionPoint, Location, Module +from mlir.dialects import func, pto + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + with Location.unknown(ctx): + module = Module.create() + fn_ty = func.FunctionType.get([], []) + + with InsertionPoint(module.body): + fn = func.FuncOp("test_intercore_sync_a3_missing_setffts", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + pipe_fix = pto.PipeAttr.get(pto.PIPE.PIPE_FIX, ctx) + pipe_v = pto.PipeAttr.get(pto.PIPE.PIPE_V, ctx) + pto.sync_set(pipe_fix, 7) + pto.sync_wait(pipe_v, 7) + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/Sync/test_intercore_sync_a5.py b/test/samples/Sync/test_intercore_sync_a5.py new file mode 100644 index 00000000..8c887833 --- /dev/null +++ b/test/samples/Sync/test_intercore_sync_a5.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +from mlir.ir import Context, InsertionPoint, Location, Module +from mlir.dialects import func, pto + + +def build(): + with Context() as ctx: + pto.register_dialect(ctx, load=True) + with Location.unknown(ctx): + module = Module.create() + fn_ty = func.FunctionType.get([], []) + + with InsertionPoint(module.body): + fn = func.FuncOp("test_intercore_sync_a5", fn_ty) + entry = fn.add_entry_block() + + with InsertionPoint(entry): + pipe_fix = pto.PipeAttr.get(pto.PIPE.PIPE_FIX, ctx) + pipe_v = pto.PipeAttr.get(pto.PIPE.PIPE_V, ctx) + pto.sync_set(pipe_fix, 5) + pto.sync_wait(pipe_v, 5) + func.ReturnOp([]) + + module.operation.verify() + return module + + +if __name__ == "__main__": + print(build()) diff --git a/test/samples/runop.sh b/test/samples/runop.sh index ee50e931..33df3978 100755 --- a/test/samples/runop.sh +++ b/test/samples/runop.sh @@ -198,6 +198,20 @@ process_one_dir() { fi fi + # Inter-core sync regression samples are arch-specific. + if [[ "$base" == "test_intercore_sync_a5" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" != "a5" ]]; then + echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a5" + continue + fi + if [[ "$base" == "test_intercore_sync_a3" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" != "a3" ]]; then + echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a3" + continue + fi + if [[ "$base" == "test_intercore_sync_a3_missing_setffts" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" != "a3" ]]; then + echo -e "${A}(${base}.py)\tSKIP\trequires --pto-arch=a3" + continue + fi + # Some samples are expected to fail depending on the selected ptoas flags. # # alloc_tile_addr.py uses `pto.alloc_tile addr=...`, which is only accepted @@ -220,6 +234,9 @@ process_one_dir() { fi [[ $has_level3 -eq 1 ]] || expect_fail=1 fi + if [[ "$base" == "test_intercore_sync_a3_missing_setffts" && "$(printf '%s' "$target_arch" | tr '[:upper:]' '[:lower:]')" == "a3" ]]; then + expect_fail=1 + fi mlir="${out_subdir}/${base}-pto-ir.pto" cpp="${out_subdir}/${base}-pto.cpp" @@ -261,8 +278,16 @@ process_one_dir() { # Write output via -o to avoid mixing debug prints with generated C++. local -a ptoas_cmd=("${ptoas_cmd_base[@]}" "$pto_input" -o "$cpp") - if ! "${ptoas_cmd[@]}" >/dev/null 2>&1; then + local ptoas_log="${out_subdir}/${base}-ptoas.log" + if ! "${ptoas_cmd[@]}" >"${ptoas_log}" 2>&1; then if [[ $expect_fail -eq 1 ]]; then + if [[ "$base" == "test_intercore_sync_a3_missing_setffts" ]]; then + if ! grep -Eq "A3 inter-core sync requires explicit .*pto.set_ffts" "${ptoas_log}"; then + echo -e "${A}(${base}.py)\tFAIL\texpected missing-set_ffts diagnostic not found" + overall=1 + continue + fi + fi echo -e "${A}(${base}.py)\tXFAIL\tptoas failed as expected" continue fi @@ -317,6 +342,53 @@ process_one_dir() { fi fi + # Inter-core sync regression: A3/A5 must lower pto.sync.set/wait to + # architecture-specific ISA interfaces. + if [[ "$base" == "test_intercore_sync_a3" ]]; then + if ! grep -Fq "set_ffts_base_addr(" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing set_ffts_base_addr() lowering" + overall=1 + continue + fi + if ! grep -Fq "ffts_cross_core_sync(PIPE_FIX" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing A3 sync.set lowering to ffts_cross_core_sync" + overall=1 + continue + fi + if ! grep -Fq "getFFTSMsg(FFTS_MODE_VAL," "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing A3 getFFTSMsg(FFTS_MODE_VAL, ...) encoding" + overall=1 + continue + fi + if ! grep -Fq "wait_flag_dev(3)" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing A3 sync.wait lowering to wait_flag_dev(event_id)" + overall=1 + continue + fi + if grep -Fq "wait_flag_dev(PIPE_" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tunexpected wait_flag_dev(pipe, event_id) lowering on A3" + overall=1 + continue + fi + fi + if [[ "$base" == "test_intercore_sync_a5" ]]; then + if ! grep -Fq "set_intra_block(PIPE_FIX, 5)" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing A5 sync.set lowering to set_intra_block" + overall=1 + continue + fi + if ! grep -Fq "wait_intra_block(PIPE_V, 5)" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tmissing A5 sync.wait lowering to wait_intra_block" + overall=1 + continue + fi + if grep -Fq "ffts_cross_core_sync(" "$cpp" || grep -Fq "wait_flag_dev(" "$cpp"; then + echo -e "${A}(${base}.py)\tFAIL\tunexpected A3-style inter-core sync call in A5 output" + overall=1 + continue + fi + fi + # Regression guard for issue #185: barrier_sync must support op types # beyond TMATMUL/TVEC and lower to the expected per-pipe barrier. if [[ "$base" == "test_barrier_sync" ]]; then