159 changes: 120 additions & 39 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -23,6 +23,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cmath>
#include <numeric>
@@ -3981,69 +3982,149 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
AtenSliceTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
Value self = adaptor.getSelf();
auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");

// Only statically deducible values are currently supported
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

dim = toPositiveDim(dim, selfType.getRank());

if (!isValidDim(dim, selfType.getRank()))
return rewriter.notifyMatchFailure(op, "dim must less than tensor rank");
return rewriter.notifyMatchFailure(op, "dim out of range");

SmallVector<int64_t> inputShape =
llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
const int64_t K = inputShape[dim];

int64_t start;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(op, "start must be a Scalar constant");

if (start < 0) {
start = toPositiveDim(start, selfType.getShape()[dim]);
if (!isValidDim(start, selfType.getShape()[dim]))
return rewriter.notifyMatchFailure(op, "start is not a valid index");
}
start = std::min(selfType.getShape()[dim], start);
// Torch accepts negative `start`/`end`; translate them to positive indices in
// the canonical [0, K] range before clamping.
if (start < 0)
start = toPositiveDim(start, K);
start = std::clamp<int64_t>(start, /*Min=*/0, /*Max=*/K);
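// E.g. with K = 5: start = -2 maps to 3; start = -7 maps to -2 and is then
// clamped to 0, matching Torch's out-of-lower-bound semantics.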

int64_t end;
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp()))
end = selfType.getShape()[dim];
else
return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
}
// support for end < 0
end = toPositiveDim(end, selfType.getShape()[dim]);
// support for end out of upper bound
end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end);

// FIXME: add support for start < 0 and end < start
if (end < start)
return rewriter.notifyMatchFailure(op,
"Currently unsupported: end < start");
if (matchPattern(op.getEnd(), m_TorchConstantInt(&end))) {
if (end == std::numeric_limits<int64_t>::max())
end = K;
} else if (isa<ConstantNoneOp>(op.getEnd().getDefiningOp())) {
end = K;
} else {
return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
}
if (end < 0)
end = toPositiveDim(end, K);
end = std::clamp<int64_t>(end, /*Min=*/0, /*Max=*/K);
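// E.g. with K = 5: end = -1 maps to 4; end = 100 is clamped to 5, matching
// Torch's out-of-upper-bound semantics.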

int64_t step;
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(op, "step must be a Scalar constant");
if (step <= 0)
return rewriter.notifyMatchFailure(op, "step <= 0 unsupported");

if (step != 1)
return rewriter.notifyMatchFailure(
op, "step value other than 1 is currently unsupported");
auto loc = op->getLoc();
auto elemTy = selfType.getElementType();

SmallVector<int64_t> startSlice(selfType.getRank(), 0);
SmallVector<int64_t> sizeSlice =
llvm::to_vector(makeShapeTorchCompatible(selfType.getShape()));
auto convertedResultTy = dyn_cast_or_null<RankedTensorType>(
getTypeConverter()->convertType(op.getType()));
if (!convertedResultTy || !convertedResultTy.hasStaticShape())
return rewriter.notifyMatchFailure(op,
"result type must be statically shaped");

// When the stride is 1, the original tosa.slice lowering is still optimal.
if (step == 1) {
SmallVector<int64_t> startSlice(selfType.getRank(), 0);
SmallVector<int64_t> sizeSlice = inputShape;
startSlice[dim] = start;
sizeSlice[dim] = std::max<int64_t>(end - start, 0);
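// E.g. a [2, 5, 3] input sliced on dim = 1 with start = 1, end = 4 yields
// startSlice = {0, 1, 0} and sizeSlice = {2, 3, 3}.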

rewriter.replaceOpWithNewOp<tosa::SliceOp>(
op, convertedResultTy, self,
tosa::getTosaConstShape(rewriter, loc, startSlice),
tosa::getTosaConstShape(rewriter, loc, sizeSlice));
return success();
}

int64_t N = 1, C = 1;
for (int64_t i = 0; i < dim; ++i)
N *= inputShape[i];
for (int64_t i = dim + 1; i < (int64_t)inputShape.size(); ++i)
C *= inputShape[i];
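
// E.g. slicing dim = 1 of a [2, 5, 3] tensor gives N = 2, K = 5, C = 3;
// slicing dim = 0 gives N = 1, K = 2, C = 15.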

// Stride > 1: rewrite Torch slicing into TOSA as follows:
// 1) reshape the tensor to [N, K, C] so the sliced dimension is isolated,
// 2) materialize the index vector {start + i*step},
// 3) tile indices across the batch dimension and gather the desired rows,
// 4) reshape the gathered result back to the original rank.
// Number of elements that survive after applying the stride.
int64_t W = (end > start) ? ((end - start + step - 1) / step) : 0;
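// E.g. start = 0, end = 5, step = 2 keeps indices {0, 2, 4}, so W = 3.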

SmallVector<int64_t> nkcShape = {N, K, C};
auto nkcTy = RankedTensorType::get(makeShapeLLVMCompatible(nkcShape), elemTy);
// Reshape the input tensor into [N, K, C] so that the sliced dimension
// becomes the middle axis (K) and all prefix/suffix dimensions are grouped
// into batch (N) and channel (C) components. When the original tensor is
// already three-dimensional with this layout, reuse it directly.
Value reshaped = (inputShape.size() == 3 && inputShape[0] == N &&
inputShape[1] == K && inputShape[2] == C)
? self
: tosa::ReshapeOp::create(
rewriter, loc, nkcTy, self,
tosa::getTosaConstShape(rewriter, loc, nkcShape))
.getResult();

// Build the 1-D index vector [start, start + step, ...] that encodes the
// positions we want to gather from the K dimension.
SmallVector<int32_t> idxVals;
idxVals.reserve(W);
for (int64_t i = 0; i < W; ++i)
idxVals.push_back(static_cast<int32_t>(start + i * step));
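// E.g. start = 1, step = 2, W = 3 produces idxVals = {1, 3, 5}.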

auto idx1DTy = RankedTensorType::get({W}, rewriter.getI32Type());
auto idxAttr = DenseIntElementsAttr::get(idx1DTy, idxVals);
Value idx1D =
tosa::ConstOp::create(rewriter, loc, idx1DTy, idxAttr).getResult();

// Gather expects a 2-D index tensor, so reshape to [1, W] prior to tiling.
auto idx1xWTy = RankedTensorType::get({1, W}, rewriter.getI32Type());
Value idx1xW =
tosa::ReshapeOp::create(
rewriter, loc, idx1xWTy, idx1D,
tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{1, W}))
.getResult();

startSlice[dim] = start;
sizeSlice[dim] = end - start;
// Tile the single row of indices across the batch dimension so every
// [batch, channel] slice uses the same sequence.
auto tileMul =
tosa::getTosaConstShape(rewriter, loc, SmallVector<int64_t>{N, 1});
auto idxNWTy = RankedTensorType::get({N, W}, rewriter.getI32Type());
Value idxNW =
tosa::TileOp::create(rewriter, loc, idxNWTy, idx1xW, tileMul).getResult();
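// E.g. N = 2 tiles [[0, 2, 4]] into [[0, 2, 4], [0, 2, 4]].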

// Gather rows along the K axis using the tiled indices; a single tosa.gather
// materializes the strided slice.
auto gatherTy = RankedTensorType::get({N, W, C}, elemTy);
Value gathered =
tosa::GatherOp::create(rewriter, loc, gatherTy, reshaped, idxNW)
.getResult();
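// tosa.gather computes gathered[n][w][c] = reshaped[n][idxNW[n][w]][c], so
// every batch row keeps exactly the strided positions along K.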

rewriter.replaceOpWithNewOp<tosa::SliceOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice),
tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice));
SmallVector<int64_t> outShape = inputShape;

Review comment: Nit: seems like an unnecessary copy here, as we do not write to inputShape after this. I can see why it's done: it would be much easier to extend the function in the future without having to go back and edit this, and there is also a readability benefit, so happy to ignore if this is on purpose.

outShape[dim] = W;
assert(llvm::equal(convertedResultTy.getShape(), outShape) &&
"type converter mismatch for slice result");

// Restore the original rank with the newly strided dimension size.
Value result =
tosa::ReshapeOp::create(rewriter, loc, convertedResultTy, gathered,
tosa::getTosaConstShape(rewriter, loc, outShape))
.getResult();

rewriter.replaceOp(op, result);
return success();
}

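For readers new to this decomposition, below is a minimal standalone sketch (not part of the patch) of the same N/K/C gather scheme applied to a flat row-major buffer; the shapes and slice parameters are illustrative assumptions.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Illustrative input: shape [2, 5, 3], slicing dim 1 with start = 0,
  // end = 5, step = 2, i.e. N = 2, K = 5, C = 3.
  const int64_t N = 2, K = 5, C = 3, start = 0, end = 5, step = 2;
  const int64_t W = (end - start + step - 1) / step; // 3 surviving rows

  std::vector<int64_t> in(N * K * C);
  for (int64_t i = 0; i < N * K * C; ++i)
    in[i] = i;

  // Index vector {start + i*step}, shared by every batch (the tile step).
  std::vector<int64_t> idx(W);
  for (int64_t i = 0; i < W; ++i)
    idx[i] = start + i * step;

  // Gather: out[n][w][c] = in[n][idx[w]][c] (the tosa.gather step).
  std::vector<int64_t> out(N * W * C);
  for (int64_t n = 0; n < N; ++n)
    for (int64_t w = 0; w < W; ++w)
      for (int64_t c = 0; c < C; ++c)
        out[(n * W + w) * C + c] = in[(n * K + idx[w]) * C + c];

  for (int64_t v : out)
    std::cout << v << ' '; // rows 0, 2 and 4 of each batch
  std::cout << '\n';
  return 0;
}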
2 changes: 0 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3930,8 +3930,6 @@
"SliceCopyStartGreaterThanDimSize_Module_basic",
"SliceEndSleStartModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SortTensorDescending_basic",