Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 7 additions & 34 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3838,51 +3838,24 @@ LogicalResult SubsetOp::inferReturnTypes(
resultShape.push_back(dim);
}

// Derive valid shape from parent valid dims when possible.
// Derive result valid shape from static sizes and parent valid shape.
// NOTE: Do not make this depend on subset offsets. Offsets may be dynamic,
// but they should not force a static valid dim to become dynamic.
SmallVector<int64_t> validShape;
constexpr int64_t kDynamicValidDim = -1;
ArrayRef<int64_t> parentValid = sourceType.getValidShape();
bool sameRank = parentValid.size() == resultShape.size();
validShape.reserve(resultShape.size());
for (size_t i = 0, e = resultShape.size(); i < e; ++i) {
int64_t sizeDim = resultShape[i];
int64_t vdim = sizeDim;

if (parentValid.size() == resultShape.size()) {
if (sameRank) {
int64_t pv = parentValid[i];
if (pv < 0) {
vdim = kDynamicValidDim;
} else {
int64_t off = 0;
// operands: [source, offsets...]
if (operands.size() > 1 + i) {
auto offOpt = getConstIndexValue(operands[1 + i]);
if (!offOpt) {
vdim = kDynamicValidDim;
validShape.push_back(vdim);
continue;
}
off = *offOpt;
// Interpret parent valid dims as a per-tile "period" when the parent
// buffer is wider than the valid region (e.g. ping/pong workspace).
// This avoids inferring a zero valid dim when taking a view at an
// offset equal to the parent valid dim.
//
// Example:
// parent: shape 32x64, valid 32x32
// subset: offset [0,32], sizes [32,32]
// should infer v_col=32 (not 0).
int64_t diff = 0;
if (pv > 0) {
int64_t offMod = off % pv;
if (offMod < 0)
offMod += pv;
diff = pv - offMod; // in [1, pv] when pv>0
}
if (diff < 0)
diff = 0;
vdim = std::min<int64_t>(sizeDim, diff);
} else {
vdim = kDynamicValidDim;
}
vdim = std::min<int64_t>(sizeDim, pv);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve static-offset clipping in subset valid-shape inference

The new inference path ignores subset offsets entirely and always uses min(size, parentValid), which regresses constant-offset cases where the valid region should shrink. For example, with parent v_col=32, sizes=[...,32], and a static col offset of 16, this now infers v_col=32 instead of 16; that overstates validity and diverges from the lowering path (computeSubsetValidDim in PTOViewToMemref) that still applies offset-based clipping. This can make earlier IR passes treat padded elements as valid and produce inconsistent metadata across the pipeline.

Useful? React with 👍 / 👎.

}
}

Expand Down
14 changes: 14 additions & 0 deletions test/basic/subset_infer_dynamic_offset.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: ptoas %s | FileCheck %s

module {
func.func @subset_infer_dynamic_offset(%off : index) {
%c0 = arith.constant 0 : index
%ws = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f16, rows=32, cols=64, v_row=32, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>
%sub = "pto.subset"(%ws, %c0, %off) {sizes = [32, 64]} :
(!pto.tile_buf<loc=vec, dtype=f16, rows=32, cols=64, v_row=32, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>, index, index)
-> !pto.tile_buf<loc=vec, dtype=f16, rows=32, cols=64, v_row=32, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>
return
}
}

// CHECK: Success
Loading