From 3dfb0197ee3022b7377ab92f244b4251ec5bad8c Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Sat, 14 Mar 2026 13:19:49 +0800 Subject: [PATCH 1/2] Fix scf.for loop-carried type reconciliation in PTOViewToMemref --- lib/PTO/Transforms/PTOViewToMemref.cpp | 60 +++++++++++++++++++ .../basic/scf_for_tile_iter_arg_reconcile.pto | 29 +++++++++ 2 files changed, 89 insertions(+) create mode 100644 test/basic/scf_for_tile_iter_arg_reconcile.pto diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index 9bc0b95b..f40574af 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -449,6 +449,62 @@ static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) { return success(); } +// Ensure scf.for loop-carried types follow the rewritten init/yield operand +// types. +// +// PTOViewToMemref rewrites tile values to memref inside loop bodies, but the +// scf.for region iter_args and results are not updated automatically. For +// loop-carried values this leaves the op structurally inconsistent: +// init arg : memref<...> +// region arg : !pto.tile_buf<...> +// scf.yield : memref<...> +// scf.for res : !pto.tile_buf<...> +// +// We only reconcile when the external init operand and yielded value already +// agree on the same type; otherwise this pass has not fully lowered the loop +// consistently and we should fail loudly instead of guessing. +static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) { + SmallVector forOps; + func.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); }); + + for (scf::ForOp forOp : forOps) { + if (forOp.getNumRegionIterArgs() == 0) + continue; + + auto yieldOp = + dyn_cast(forOp.getBody()->getTerminator()); + if (!yieldOp) { + forOp.emitError("loop-carried scf.for must end with scf.yield"); + return failure(); + } + + if (yieldOp.getNumOperands() != forOp.getNumRegionIterArgs()) { + forOp.emitError("scf.for iter_arg count does not match yielded values"); + return failure(); + } + + for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) { + Type initTy = forOp.getInitArgs()[i].getType(); + Type yieldTy = yieldOp.getOperand(i).getType(); + if (initTy != yieldTy) { + forOp.emitError() + << "scf.for loop-carried type mismatch at iter_arg #" << i + << ": init=" << initTy << ", yield=" << yieldTy; + return failure(); + } + + BlockArgument iterArg = forOp.getRegionIterArg(i); + if (iterArg.getType() != initTy) + iterArg.setType(initTy); + + if (forOp.getResult(i).getType() != initTy) + forOp.getResult(i).setType(initTy); + } + } + + return success(); +} + // ============================================================================= // The Pass Implementation // ============================================================================= @@ -2604,6 +2660,10 @@ struct PTOViewToMemrefPass // ------------------------------------------------------------------ // Stage 4: Reconcile control-flow result types // ------------------------------------------------------------------ + if (failed(reconcileSCFForResultTypes(func))) { + signalPassFailure(); + return; + } if (failed(reconcileSCFIfResultTypes(func))) { signalPassFailure(); return; diff --git a/test/basic/scf_for_tile_iter_arg_reconcile.pto b/test/basic/scf_for_tile_iter_arg_reconcile.pto new file mode 100644 index 00000000..0ee5306b --- /dev/null +++ b/test/basic/scf_for_tile_iter_arg_reconcile.pto @@ -0,0 +1,29 @@ +// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @for_tile_iter_arg_reconcile() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %zero = arith.constant 0.000000e+00 : f32 + + %acc0 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + pto.tmuls ins(%acc0, %zero : !pto.tile_buf, f32) outs(%acc0 : !pto.tile_buf) + pto.tmuls ins(%tmp, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf) + + %acc = scf.for %i = %c0 to %c2 step %c1 iter_args(%iter = %acc0) -> (!pto.tile_buf) { + pto.tadd ins(%iter, %tmp : !pto.tile_buf, !pto.tile_buf) outs(%tmp : !pto.tile_buf) + scf.yield %tmp : !pto.tile_buf + } + + pto.tmuls ins(%acc, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf) + return + } +} + +// CHECK: scf.for +// CHECK-SAME: iter_args(%{{.*}} = %{{.*}}) -> (memref<16x1xf32 +// CHECK: pto.tadd ins(%{{.*}}, %{{.*}} : memref<16x1xf32 +// CHECK: scf.yield %{{.*}} : memref<16x1xf32 +// CHECK-NOT: error: From ee92a6b6a053ff063bafff27a86063aa784773ff Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Sat, 14 Mar 2026 14:01:43 +0800 Subject: [PATCH 2/2] Reconcile scf.if before scf.for in PTOViewToMemref --- lib/PTO/Transforms/PTOViewToMemref.cpp | 4 +-- .../scf_if_for_tile_iter_arg_reconcile.pto | 35 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 test/basic/scf_if_for_tile_iter_arg_reconcile.pto diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp index f40574af..4c157bf4 100644 --- a/lib/PTO/Transforms/PTOViewToMemref.cpp +++ b/lib/PTO/Transforms/PTOViewToMemref.cpp @@ -2660,11 +2660,11 @@ struct PTOViewToMemrefPass // ------------------------------------------------------------------ // Stage 4: Reconcile control-flow result types // ------------------------------------------------------------------ - if (failed(reconcileSCFForResultTypes(func))) { + if (failed(reconcileSCFIfResultTypes(func))) { signalPassFailure(); return; } - if (failed(reconcileSCFIfResultTypes(func))) { + if (failed(reconcileSCFForResultTypes(func))) { signalPassFailure(); return; } diff --git a/test/basic/scf_if_for_tile_iter_arg_reconcile.pto b/test/basic/scf_if_for_tile_iter_arg_reconcile.pto new file mode 100644 index 00000000..5f22aafe --- /dev/null +++ b/test/basic/scf_if_for_tile_iter_arg_reconcile.pto @@ -0,0 +1,35 @@ +// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s + +module { + func.func @if_for_interaction(%cond: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %zero = arith.constant 0.000000e+00 : f32 + + %acc0 = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + pto.tmuls ins(%acc0, %zero : !pto.tile_buf, f32) outs(%acc0 : !pto.tile_buf) + pto.tmuls ins(%tmp, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf) + + %init = scf.if %cond -> (!pto.tile_buf) { + scf.yield %acc0 : !pto.tile_buf + } else { + scf.yield %tmp : !pto.tile_buf + } + + %acc = scf.for %i = %c0 to %c2 step %c1 iter_args(%iter = %init) -> (!pto.tile_buf) { + pto.tadd ins(%iter, %tmp : !pto.tile_buf, !pto.tile_buf) outs(%tmp : !pto.tile_buf) + scf.yield %tmp : !pto.tile_buf + } + + pto.tmuls ins(%acc, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf) + return + } +} + +// CHECK: %{{.*}} = scf.if %{{.*}} -> (memref<16x1xf32 +// CHECK: scf.for +// CHECK-SAME: iter_args(%{{.*}} = %{{.*}}) -> (memref<16x1xf32 +// CHECK: scf.yield %{{.*}} : memref<16x1xf32 +// CHECK-NOT: error: