From e0733a11f208a14944242c3635318770f7d32aba Mon Sep 17 00:00:00 2001 From: lishengtao Date: Fri, 13 Mar 2026 11:26:34 +0800 Subject: [PATCH] Fix subset valid-shape inference with dynamic offsets --- lib/PTO/IR/PTO.cpp | 41 ++++----------------- test/basic/subset_infer_dynamic_offset.mlir | 14 +++++++ 2 files changed, 21 insertions(+), 34 deletions(-) create mode 100644 test/basic/subset_infer_dynamic_offset.mlir diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 02670518..5c316a78 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -3838,51 +3838,24 @@ LogicalResult SubsetOp::inferReturnTypes( resultShape.push_back(dim); } - // Derive valid shape from parent valid dims when possible. + // Derive result valid shape from static sizes and parent valid shape. + // NOTE: Do not make this depend on subset offsets. Offsets may be dynamic, + // but they should not force a static valid dim to become dynamic. SmallVector validShape; constexpr int64_t kDynamicValidDim = -1; ArrayRef parentValid = sourceType.getValidShape(); + bool sameRank = parentValid.size() == resultShape.size(); + validShape.reserve(resultShape.size()); for (size_t i = 0, e = resultShape.size(); i < e; ++i) { int64_t sizeDim = resultShape[i]; int64_t vdim = sizeDim; - if (parentValid.size() == resultShape.size()) { + if (sameRank) { int64_t pv = parentValid[i]; if (pv < 0) { vdim = kDynamicValidDim; } else { - int64_t off = 0; - // operands: [source, offsets...] - if (operands.size() > 1 + i) { - auto offOpt = getConstIndexValue(operands[1 + i]); - if (!offOpt) { - vdim = kDynamicValidDim; - validShape.push_back(vdim); - continue; - } - off = *offOpt; - // Interpret parent valid dims as a per-tile "period" when the parent - // buffer is wider than the valid region (e.g. ping/pong workspace). - // This avoids inferring a zero valid dim when taking a view at an - // offset equal to the parent valid dim. - // - // Example: - // parent: shape 32x64, valid 32x32 - // subset: offset [0,32], sizes [32,32] - // should infer v_col=32 (not 0). - int64_t diff = 0; - if (pv > 0) { - int64_t offMod = off % pv; - if (offMod < 0) - offMod += pv; - diff = pv - offMod; // in [1, pv] when pv>0 - } - if (diff < 0) - diff = 0; - vdim = std::min(sizeDim, diff); - } else { - vdim = kDynamicValidDim; - } + vdim = std::min(sizeDim, pv); } } diff --git a/test/basic/subset_infer_dynamic_offset.mlir b/test/basic/subset_infer_dynamic_offset.mlir new file mode 100644 index 00000000..d326bdf7 --- /dev/null +++ b/test/basic/subset_infer_dynamic_offset.mlir @@ -0,0 +1,14 @@ +// RUN: ptoas %s | FileCheck %s + +module { + func.func @subset_infer_dynamic_offset(%off : index) { + %c0 = arith.constant 0 : index + %ws = pto.alloc_tile : !pto.tile_buf + %sub = "pto.subset"(%ws, %c0, %off) {sizes = [32, 64]} : + (!pto.tile_buf, index, index) + -> !pto.tile_buf + return + } +} + +// CHECK: Success