Skip to content

Commit 57a0bf4

Browse files
authored
[mlir][linalg] Do not set insertion point inside padding function (#165420)
Remove insertion point in rewriteAsPaddedOp. There is no gurantee that the sizes provided by the user are before the operation to pad. It's better to let the user handle where to insert the newly created operations, as long as they are after the origin operation to pad.
1 parent 49f918d commit 57a0bf4

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ struct PadTilingInterfaceResult {
653653
// interpreted as the bounding box (dynamic) value to pad to.
654654
/// * Use "options.paddingValues" to set the padding value of the created
655655
// tensor::PadOp.
656+
//
657+
// The transformation assumes that the insertion point is set after the
658+
// operation to pad.
656659
FailureOr<PadTilingInterfaceResult>
657660
rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
658661
PadTilingInterfaceOptions options,

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,6 +2464,8 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
24642464
.setPaddingSizes(getMixedPaddingSizes())
24652465
.setPadToMultipleOf(getPadToMultipleOf());
24662466

2467+
OpBuilder::InsertionGuard g(rewriter);
2468+
rewriter.setInsertionPointAfter(targetOp);
24672469
auto maybePadOps = rewriteAsPaddedOp(
24682470
rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
24692471
if (failed(maybePadOps)) {

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,10 +288,6 @@ FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
288288
return failure();
289289
}
290290

291-
OpBuilder::InsertionGuard g(builder);
292-
// Set IP after toPad because we also take the dims of toPad's output.
293-
builder.setInsertionPointAfter(toPad);
294-
295291
// 1. Get the loopUpperBounds from the TilingInterface.
296292
SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
297293

0 commit comments

Comments
 (0)