diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index c959f06c6a66..824c6b6e0c98 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -23,6 +23,9 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
 #include <cmath>
+#include <limits>
 #include <numeric>
@@ -3981,69 +3982,149 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
     AtenSliceTensorOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
+  Value self = adaptor.getSelf();
+  auto selfType = dyn_cast<TensorType>(self.getType());
   if (!selfType || !selfType.hasStaticShape())
     return rewriter.notifyMatchFailure(
         op, "Only tensor types with static shape are supported");
 
-  // Only statically deducible values are currently supported
   int64_t dim;
   if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
     return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
-
   dim = toPositiveDim(dim, selfType.getRank());
   if (!isValidDim(dim, selfType.getRank()))
-    return rewriter.notifyMatchFailure(op, "dim must less than tensor rank");
+    return rewriter.notifyMatchFailure(op, "dim out of range");
+
+  SmallVector<int64_t> inputShape =
+      llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
+  const int64_t K = inputShape[dim];
 
   int64_t start;
   if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
     return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");
-
-  if (start < 0) {
-    start = toPositiveDim(start, selfType.getShape()[dim]);
-    if (!isValidDim(start, selfType.getShape()[dim]))
-      return rewriter.notifyMatchFailure(op, "start is not a valid index");
-  }
-  start = std::min(selfType.getShape()[dim], start);
+  // Torch accepts negative `start`/`end`; translate them to positive indices
+  // and clamp into the canonical [0, K] range.
+  if (start < 0)
+    start = toPositiveDim(start, K);
+  start = std::clamp<int64_t>(start, /*lo=*/0, /*hi=*/K);
 
   int64_t end;
-  if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
-    if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp()))
-      end = selfType.getShape()[dim];
-    else
-      return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
-  }
-  // support for end < 0
-  end = toPositiveDim(end, selfType.getShape()[dim]);
-  // support for end out of upper bound
-  end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end);
-
-  // FIXME: add support for start < 0 and end < start
-  if (end < start)
-    return rewriter.notifyMatchFailure(op,
-                                       "Currently unsupported: end < start");
+  if (matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
+    if (end == std::numeric_limits<int64_t>::max())
+      end = K;
+  } else if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp())) {
+    end = K;
+  } else {
+    return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
+  }
+  if (end < 0)
+    end = toPositiveDim(end, K);
+  end = std::clamp<int64_t>(end, /*lo=*/0, /*hi=*/K);
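+  // For example (illustrative values): with K = 5, start = -2 maps to 3 via
+  // toPositiveDim, and end = 100 clamps down to 5, matching PyTorch's
+  // out-of-bounds slice semantics.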
 
   int64_t step;
   if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
     return rewriter.notifyMatchFailure(op, "step must be a Scalar constant");
+  if (step <= 0)
+    return rewriter.notifyMatchFailure(op, "step <= 0 unsupported");
 
-  if (step != 1)
-    return rewriter.notifyMatchFailure(
-        op, "step value other than 1 is currently unsupported");
+  auto loc = op->getLoc();
+  auto elemTy = selfType.getElementType();
 
-  SmallVector<int64_t> startSlice(selfType.getRank(), 0);
-  SmallVector<int64_t> sizeSlice =
-      llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
+  auto convertedResultTy = dyn_cast_or_null<RankedTensorType>(
+      getTypeConverter()->convertType(op.getType()));
+  if (!convertedResultTy || !convertedResultTy.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "result type must be statically shaped");
+
+  // When the stride is 1, the original tosa.slice lowering is still optimal.
+  if (step == 1) {
+    SmallVector<int64_t> startSlice(selfType.getRank(), 0);
+    SmallVector<int64_t> sizeSlice = inputShape;
+    startSlice[dim] = start;
+    sizeSlice[dim] = std::max<int64_t>(end - start, 0);
+
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(
+        op, convertedResultTy, self,
+        tosa::getTosaConstShape(rewriter, loc, startSlice),
+        tosa::getTosaConstShape(rewriter, loc, sizeSlice));
+    return success();
+  }
+
+  int64_t N = 1, C = 1;
+  for (int64_t i = 0; i < dim; ++i)
+    N *= inputShape[i];
+  for (int64_t i = dim + 1; i < static_cast<int64_t>(inputShape.size()); ++i)
+    C *= inputShape[i];
+
+  // Stride > 1: rewrite Torch slicing into TOSA as follows:
+  //   1) reshape the tensor to [N, K, C] so the sliced dimension is isolated,
+  //   2) materialize the index vector {start + i*step},
+  //   3) tile indices across the batch dimension and gather the desired rows,
+  //   4) reshape the gathered result back to the original rank.
+
+  // Number of elements that survive after applying the stride.
+  int64_t W = (end > start) ? ((end - start + step - 1) / step) : 0;
+
+  SmallVector<int64_t> nkcShape = {N, K, C};
+  auto nkcTy =
+      RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), elemTy);
+  // Reshape the input tensor into [N, K, C] so that the sliced dimension
+  // becomes the middle axis (K) and all prefix/suffix dimensions are grouped
+  // into batch (N) and channel (C) components. When the original tensor
+  // already has this three-dimensional layout, reuse it directly.
+  Value reshaped = (inputShape.size() == 3 && inputShape[0] == N &&
+                    inputShape[1] == K && inputShape[2] == C)
+                       ? self
+                       : tosa::ReshapeOp::create(
+                             rewriter, loc, nkcTy, self,
+                             tosa::getTosaConstShape(rewriter, loc, nkcShape))
+                             .getResult();
+
+  // Build the 1-D index vector [start, start + step, ...] that encodes the
+  // positions we want to gather from the K dimension.
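+  // Worked example (illustrative shapes): slicing dim 1 of a [2, 10, 3]
+  // tensor with start=0, end=5, step=2 gives N=2, K=10, C=3, W=3 and
+  // indices {0, 2, 4}; the gather below then yields a [2, 3, 3] result.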
+  SmallVector<int32_t> idxVals;
+  idxVals.reserve(W);
+  for (int64_t i = 0; i < W; ++i)
+    idxVals.push_back(static_cast<int32_t>(start + i * step));
+
+  auto idx1DTy = RankedTensorType::get({W}, rewriter.getI32Type());
+  auto idxAttr = DenseIntElementsAttr::get(idx1DTy, idxVals);
+  Value idx1D =
+      tosa::ConstOp::create(rewriter, loc, idx1DTy, idxAttr).getResult();
+
+  // tosa.gather expects a 2-D index tensor, so reshape to [1, W] prior to
+  // tiling.
+  auto idx1xWTy = RankedTensorType::get({1, W}, rewriter.getI32Type());
+  Value idx1xW =
+      tosa::ReshapeOp::create(
+          rewriter, loc, idx1xWTy, idx1D,
+          tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{1, W}))
+          .getResult();
 
-  startSlice[dim] = start;
-  sizeSlice[dim] = end - start;
+  // Tile the single row of indices across the batch dimension so every
+  // [batch, channel] slice uses the same sequence.
+  auto tileMul =
+      tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{N, 1});
+  auto idxNWTy = RankedTensorType::get({N, W}, rewriter.getI32Type());
+  Value idxNW =
+      tosa::TileOp::create(rewriter, loc, idxNWTy, idx1xW, tileMul)
+          .getResult();
+
+  // A single tosa.gather then materializes the strided slice: for each batch
+  // it picks rows {start + i*step} from the K axis.
+  auto gatherTy = RankedTensorType::get({N, W, C}, elemTy);
+  Value gathered =
+      tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
+          .getResult();
 
-  rewriter.replaceOpWithNewOp<tosa::SliceOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
-      tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
-      tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));
+  SmallVector<int64_t> outShape = inputShape;
+  outShape[dim] = W;
+  assert(llvm::equal(convertedResultTy.getShape(), outShape) &&
+         "type converter mismatch for slice result");
+  // Restore the original rank with the newly strided dimension size.
+  Value result =
+      tosa::ReshapeOp::create(
+          rewriter, loc, convertedResultTy, gathered,
+          tosa::getTosaConstShape(rewriter, loc, outShape))
+          .getResult();
+
+  rewriter.replaceOp(op, result);
   return success();
 }
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0e1b3b67d102..6c3c12a944b7 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3930,8 +3930,6 @@
     "SliceCopyStartGreaterThanDimSize_Module_basic",
     "SliceEndSleStartModule_basic",
     "SliceOutOfLowerBoundEndIndexModule_basic",
-    "SliceOutOfLowerBoundStartIndexModule_basic",
-    "SliceSizeTwoStepModule_basic",
     "SortIntListReverse_basic",
     "SortIntList_basic",
     "SortTensorDescending_basic",