diff --git a/lib/PTO/Transforms/PTOViewToMemref.cpp b/lib/PTO/Transforms/PTOViewToMemref.cpp
index 9bc0b95b..4c157bf4 100644
--- a/lib/PTO/Transforms/PTOViewToMemref.cpp
+++ b/lib/PTO/Transforms/PTOViewToMemref.cpp
@@ -449,6 +449,80 @@ static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) {
   return success();
 }
 
+// Ensure scf.for loop-carried types follow the rewritten init/yield operand
+// types.
+//
+// PTOViewToMemref rewrites tile values to memref inside loop bodies, but the
+// scf.for region iter_args and results are not updated automatically. For
+// loop-carried values this leaves the op structurally inconsistent:
+//   init arg    : memref<...>
+//   region arg  : !pto.tile_buf<...>
+//   scf.yield   : memref<...>
+//   scf.for res : !pto.tile_buf<...>
+//
+// Reconciliation runs in two phases so nested and chained loops work:
+//   1. In pre-order, copy every init operand type onto the matching region
+//      iter_arg and loop result. Outer and earlier loops are visited first,
+//      so a loop whose init operand is an enclosing loop's iter_arg or an
+//      earlier loop's result sees the already-updated type.
+//   2. Check that every scf.yield operand has its init operand's type. A
+//      mismatch means this pass has not lowered the loop consistently, so we
+//      fail loudly instead of guessing.
+//
+// Returns failure (after emitting an error on the offending scf.for) when a
+// loop does not end in scf.yield, when the yield operand count differs from
+// the iter_arg count, or when an init/yield type pair disagrees.
+static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) {
+  // Only loops that carry values need reconciling.
+  SmallVector<scf::ForOp, 8> forOps;
+  func.walk<WalkOrder::PreOrder>([&](scf::ForOp forOp) {
+    if (forOp.getNumRegionIterArgs() != 0)
+      forOps.push_back(forOp);
+  });
+
+  // Phase 1: the init operand is the source of truth for the carried type.
+  for (scf::ForOp forOp : forOps) {
+    for (unsigned i = 0, e = forOp.getNumRegionIterArgs(); i < e; ++i) {
+      Type initTy = forOp.getInitArgs()[i].getType();
+
+      BlockArgument iterArg = forOp.getRegionIterArg(i);
+      if (iterArg.getType() != initTy)
+        iterArg.setType(initTy);
+
+      if (forOp.getResult(i).getType() != initTy)
+        forOp.getResult(i).setType(initTy);
+    }
+  }
+
+  // Phase 2: every yielded value must now agree with its init operand.
+  for (scf::ForOp forOp : forOps) {
+    auto yieldOp =
+        dyn_cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    if (!yieldOp) {
+      forOp.emitError("loop-carried scf.for must end with scf.yield");
+      return failure();
+    }
+
+    if (yieldOp.getNumOperands() != forOp.getNumRegionIterArgs()) {
+      forOp.emitError("scf.for iter_arg count does not match yielded values");
+      return failure();
+    }
+
+    for (unsigned i = 0, e = forOp.getNumRegionIterArgs(); i < e; ++i) {
+      Type initTy = forOp.getInitArgs()[i].getType();
+      Type yieldTy = yieldOp.getOperand(i).getType();
+      if (initTy != yieldTy) {
+        forOp.emitError()
+            << "scf.for loop-carried type mismatch at iter_arg #" << i
+            << ": init=" << initTy << ", yield=" << yieldTy;
+        return failure();
+      }
+    }
+  }
+
+  return success();
+}
+
 // =============================================================================
 // The Pass Implementation
 // =============================================================================
@@ -2608,6 +2682,10 @@ struct PTOViewToMemrefPass
         signalPassFailure();
         return;
       }
+      if (failed(reconcileSCFForResultTypes(func))) {
+        signalPassFailure();
+        return;
+      }
     }
 
     // Debug Output
diff --git a/test/basic/scf_for_tile_iter_arg_reconcile.pto b/test/basic/scf_for_tile_iter_arg_reconcile.pto
new file mode 100644
index 00000000..0ee5306b
--- /dev/null
+++ b/test/basic/scf_for_tile_iter_arg_reconcile.pto
@@ -0,0 +1,29 @@
+// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s
+
+module {
+  func.func @for_tile_iter_arg_reconcile() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %zero = arith.constant 0.000000e+00 : f32
+
+    %acc0 = pto.alloc_tile : !pto.tile_buf
+    %tmp = pto.alloc_tile : !pto.tile_buf
+    pto.tmuls ins(%acc0, %zero : !pto.tile_buf, f32) outs(%acc0 : !pto.tile_buf)
+    pto.tmuls ins(%tmp, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf)
+
+    %acc = scf.for %i = %c0 to %c2 step %c1 iter_args(%iter = %acc0) -> (!pto.tile_buf) {
+      pto.tadd ins(%iter, %tmp : !pto.tile_buf, !pto.tile_buf) outs(%tmp : !pto.tile_buf)
+      scf.yield %tmp : !pto.tile_buf
+    }
+
+    pto.tmuls ins(%acc, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf)
+    return
+  }
+}
+
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %{{.*}}) -> (memref<16x1xf32
+// CHECK: pto.tadd ins(%{{.*}}, %{{.*}} : memref<16x1xf32
+// CHECK: scf.yield %{{.*}} : memref<16x1xf32
+// CHECK-NOT: error:
diff --git a/test/basic/scf_if_for_tile_iter_arg_reconcile.pto b/test/basic/scf_if_for_tile_iter_arg_reconcile.pto
new file mode 100644
index 00000000..5f22aafe
--- /dev/null
+++ b/test/basic/scf_if_for_tile_iter_arg_reconcile.pto
@@ -0,0 +1,35 @@
+// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s
+
+module {
+  func.func @if_for_interaction(%cond: i1) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %zero = arith.constant 0.000000e+00 : f32
+
+    %acc0 = pto.alloc_tile : !pto.tile_buf
+    %tmp = pto.alloc_tile : !pto.tile_buf
+    pto.tmuls ins(%acc0, %zero : !pto.tile_buf, f32) outs(%acc0 : !pto.tile_buf)
+    pto.tmuls ins(%tmp, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf)
+
+    %init = scf.if %cond -> (!pto.tile_buf) {
+      scf.yield %acc0 : !pto.tile_buf
+    } else {
+      scf.yield %tmp : !pto.tile_buf
+    }
+
+    %acc = scf.for %i = %c0 to %c2 step %c1 iter_args(%iter = %init) -> (!pto.tile_buf) {
+      pto.tadd ins(%iter, %tmp : !pto.tile_buf, !pto.tile_buf) outs(%tmp : !pto.tile_buf)
+      scf.yield %tmp : !pto.tile_buf
+    }
+
+    pto.tmuls ins(%acc, %zero : !pto.tile_buf, f32) outs(%tmp : !pto.tile_buf)
+    return
+  }
+}
+
+// CHECK: %{{.*}} = scf.if %{{.*}} -> (memref<16x1xf32
+// CHECK: scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %{{.*}}) -> (memref<16x1xf32
+// CHECK: scf.yield %{{.*}} : memref<16x1xf32
+// CHECK-NOT: error: