diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index bec3eb5d38cb..844ac2f17332 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1759,6 +1759,65 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
+template <>
+LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
+    AtenReflectionPad1dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  Value self = adaptor.getSelf();
+  auto selfTy = cast<RankedTensorType>(self.getType());
+  if (!selfTy.hasStaticShape()) {
+    return rewriter.notifyMatchFailure(op, "only support static shape");
+  }
+  int64_t rank = selfTy.getRank();
+  int64_t dim = rank - 1;
+
+  SmallVector<int64_t> padInts;
+  if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only support constant int pad ranges");
+  }
+  if (padInts.size() != 2) {
+    return rewriter.notifyMatchFailure(op, "pad size must be 2");
+  }
+  if (padInts[0] >= selfTy.getDimSize(dim) ||
+      padInts[1] >= selfTy.getDimSize(dim)) {
+    return rewriter.notifyMatchFailure(op,
+                                       "pad size must be less than dim size");
+  }
+
+  Value left;
+  {
+    SmallVector<int64_t> startIndices(rank, 0);
+    SmallVector<int64_t> limitIndices(selfTy.getShape().begin(),
+                                      selfTy.getShape().end());
+    SmallVector<int64_t> strides(rank, 1);
+    startIndices[dim] = 1;
+    limitIndices[dim] = padInts[0] + 1;
+    left = rewriter.create<stablehlo::SliceOp>(loc, self, startIndices,
+                                               limitIndices, strides);
+    left = rewriter.create<stablehlo::ReverseOp>(loc, left,
+                                                 ArrayRef<int64_t>({dim}));
+  }
+  Value right;
+  {
+    SmallVector<int64_t> startIndices(rank, 0);
+    SmallVector<int64_t> limitIndices(selfTy.getShape().begin(),
+                                      selfTy.getShape().end());
+    SmallVector<int64_t> strides(rank, 1);
+    startIndices[dim] = selfTy.getDimSize(dim) - 1 - padInts[1];
+    limitIndices[dim] = selfTy.getDimSize(dim) - 1;
+    right = rewriter.create<stablehlo::SliceOp>(loc, self, startIndices,
+                                                limitIndices, strides);
+    right = rewriter.create<stablehlo::ReverseOp>(loc, right,
+                                                  ArrayRef<int64_t>({dim}));
+  }
+  Value result = rewriter.create<stablehlo::ConcatenateOp>(
+      loc, ValueRange{left, self, right}, dim);
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
     AtenGeluBackwardOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -2269,6 +2328,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenScalarImplicitOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
   INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
+  INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp);
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index d6daf2c3637a..bee2c62d7202 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -829,10 +829,6 @@
     "RandnLikeDtypeModule_basic",
     "RandnLikeModule_basic",
     "RandnModule_basic",
-    "ReflectionPad1dModule2dInput_Right",
-    "ReflectionPad1dModule2dInput_basic",
-    "ReflectionPad1dModule3dInput_Left",
-    "ReflectionPad1dModule3dInput_basic",
     "ReflectionPad2dModule_Bottom",
     "ReflectionPad2dModule_Left",
     "ReflectionPad2dModule_Right",