Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions lib/PTO/Transforms/PTOViewToMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,62 @@ static LogicalResult reconcileSCFIfResultTypes(func::FuncOp func) {
return success();
}

// Ensure scf.for loop-carried types follow the rewritten init/yield operand
// types.
//
// PTOViewToMemref rewrites tile values to memref inside loop bodies, but the
// scf.for region iter_args and results are not updated automatically. For
// loop-carried values this leaves the op structurally inconsistent:
// init arg : memref<...>
// region arg : !pto.tile_buf<...>
// scf.yield : memref<...>
// scf.for res : !pto.tile_buf<...>
//
// We only reconcile when the external init operand and yielded value already
// agree on the same type; otherwise this pass has not fully lowered the loop
// consistently and we should fail loudly instead of guessing.
static LogicalResult reconcileSCFForResultTypes(func::FuncOp func) {
SmallVector<scf::ForOp, 8> forOps;
func.walk([&](scf::ForOp forOp) { forOps.push_back(forOp); });

for (scf::ForOp forOp : forOps) {
if (forOp.getNumRegionIterArgs() == 0)
continue;

auto yieldOp =
dyn_cast<scf::YieldOp>(forOp.getBody()->getTerminator());
if (!yieldOp) {
forOp.emitError("loop-carried scf.for must end with scf.yield");
return failure();
}

if (yieldOp.getNumOperands() != forOp.getNumRegionIterArgs()) {
forOp.emitError("scf.for iter_arg count does not match yielded values");
return failure();
}

for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) {
Type initTy = forOp.getInitArgs()[i].getType();
Type yieldTy = yieldOp.getOperand(i).getType();
if (initTy != yieldTy) {
forOp.emitError()
<< "scf.for loop-carried type mismatch at iter_arg #" << i
<< ": init=" << initTy << ", yield=" << yieldTy;
return failure();
}

BlockArgument iterArg = forOp.getRegionIterArg(i);
if (iterArg.getType() != initTy)
iterArg.setType(initTy);

if (forOp.getResult(i).getType() != initTy)
forOp.getResult(i).setType(initTy);
}
}

return success();
}

// =============================================================================
// The Pass Implementation
// =============================================================================
Expand Down Expand Up @@ -2608,6 +2664,10 @@ struct PTOViewToMemrefPass
signalPassFailure();
return;
}
if (failed(reconcileSCFForResultTypes(func))) {
signalPassFailure();
return;
}
}

// Debug Output
Expand Down
29 changes: 29 additions & 0 deletions test/basic/scf_for_tile_iter_arg_reconcile.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s

module {
func.func @for_tile_iter_arg_reconcile() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%zero = arith.constant 0.000000e+00 : f32

%acc0 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%tmp = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tmuls ins(%acc0, %zero : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%acc0 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.tmuls ins(%tmp, %zero : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)

%acc = scf.for %i = %c0 to %c2 step %c1 iter_args(%iter = %acc0) -> (!pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {
pto.tadd ins(%iter, %tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
scf.yield %tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
}

pto.tmuls ins(%acc, %zero : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
return
}
}

// CHECK: scf.for
// CHECK-SAME: iter_args(%{{.*}} = %{{.*}}) -> (memref<16x1xf32
// CHECK: pto.tadd ins(%{{.*}}, %{{.*}} : memref<16x1xf32
// CHECK: scf.yield %{{.*}} : memref<16x1xf32
// CHECK-NOT: error:
35 changes: 35 additions & 0 deletions test/basic/scf_if_for_tile_iter_arg_reconcile.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: ptoas %s 2>&1 1>/dev/null | FileCheck %s

module {
func.func @if_for_interaction(%cond: i1) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%zero = arith.constant 0.000000e+00 : f32

%acc0 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%tmp = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tmuls ins(%acc0, %zero : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%acc0 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
pto.tmuls ins(%tmp, %zero : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)

%init = scf.if %cond -> (!pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {
scf.yield %acc0 : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
} else {
scf.yield %tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
}

%acc = scf.for %i = %c0 to %c2 step %c1 iter_args(%iter = %init) -> (!pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>) {
pto.tadd ins(%iter, %tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
scf.yield %tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>
}

pto.tmuls ins(%acc, %zero : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>, f32) outs(%tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=row_major, slayout=none_box, fractal=512, pad=0>)
return
}
}

// CHECK: %{{.*}} = scf.if %{{.*}} -> (memref<16x1xf32
// CHECK: scf.for
// CHECK-SAME: iter_args(%{{.*}} = %{{.*}}) -> (memref<16x1xf32
// CHECK: scf.yield %{{.*}} : memref<16x1xf32
// CHECK-NOT: error:
Loading