From dcc4050d2a824ccb99989c16e1986dc84d9ba5fa Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Mon, 10 Nov 2025 19:14:30 -0500 Subject: [PATCH] [tosa] : Add e2e support for quantized matmul. --- include/torch-mlir/Conversion/Utils/Utils.h | 17 +- lib/Conversion/TorchToLinalg/Linear.cpp | 6 - lib/Conversion/TorchToTosa/TorchToTosa.cpp | 236 ++++++++++++++---- lib/Conversion/Utils/Utils.cpp | 6 + projects/pt1/e2e_testing/xfail_sets.py | 5 - test/Conversion/TorchToTosa/basic.mlir | 62 ++--- test/Conversion/TorchToTosa/quantization.mlir | 44 ++++ 7 files changed, 284 insertions(+), 92 deletions(-) create mode 100644 test/Conversion/TorchToTosa/quantization.mlir diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 5173e7f82c4b..17c0ff23c74c 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -18,6 +18,16 @@ namespace mlir { namespace torch { namespace Torch { +// Define constants +// Float 16 limits +constexpr float Float16Max = 65504.0f; +constexpr float Float16Lowest = -65504.0f; + +// BFloat 16 limits +constexpr float BFloat16Max = 3.38953139e38f; +constexpr float BFloat16Lowest = -3.38953139e38f; + +// Define utility methods LogicalResult verifyLinalgCompatibleTypes(Operation *op, PatternRewriter &rewriter); @@ -107,13 +117,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, Value input, int64_t dim); -// Float 16 limits -constexpr float Float16Max = 65504.0f; -constexpr float Float16Lowest = -65504.0f; +void getZeroPoint(Value value, Value &zeropoint); -// BFloat 16 limits -constexpr float BFloat16Max = 3.38953139e38f; -constexpr float BFloat16Lowest = -3.38953139e38f; } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 68947a953b7a..7ff30d46f2ee 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -28,12 +28,6 @@ using namespace mlir::torch::Torch; namespace { -static void getZeroPoint(Value value, Value &zeropoint) { - if (auto make = value.getDefiningOp()) { - zeropoint = make.getZeroPoint(); - } -} - // for uint8 types, we shift down by 128 so that we can faithfully // represent the quantization with signed i8 types. static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index c959f06c6a66..b5eeadfe0b89 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1481,14 +1481,16 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // and rhs. 
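+  // For quantized operands, implementations are also expected to report the
+  // operand zero points through lhsZp/rhsZp. Both are left as null Values for
+  // non-quantized inputs, in which case performMatmul substitutes zero
+  // constants.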
virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &lhs, Value &rhs) const { + Value &lhs, Value &rhs, Value &lhsZp, + Value &rhsZp) const { return rewriter.notifyMatchFailure( op, "Unimplemented matrix multiplication variant input parsing function"); } LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, Value &lhs, - Value &rhs, Value &output) const { + Value &rhs, Value &lhsZp, Value &rhsZp, + Value &output) const { auto lhsTy = cast(lhs.getType()); auto rhsTy = cast(rhs.getType()); @@ -1506,6 +1508,33 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Matmul: input datatypes mismatched"); + if (!lhsZp) { + // Initialize zero constant values as zero-points, if the op operands + // aren't quantized types + lhsZp = tosa::createZeroPointTensor(rewriter, op->getLoc(), lhsElemTy, 0) + .value(); + rhsZp = tosa::createZeroPointTensor(rewriter, op->getLoc(), rhsElemTy, 0) + .value(); + } else { + + int64_t lhsZpConst, rhsZpConst; + if (!matchPattern(lhsZp, m_TorchConstantInt(&lhsZpConst))) + return rewriter.notifyMatchFailure( + op, "Lhs zero point must be a Scalar constant"); + + lhsZp = tosa::createZeroPointTensor(rewriter, op->getLoc(), lhsElemTy, + lhsZpConst) + .value(); + + if (!matchPattern(rhsZp, m_TorchConstantInt(&rhsZpConst))) + return rewriter.notifyMatchFailure( + op, "Rhs zero point must be a Scalar constant"); + + rhsZp = tosa::createZeroPointTensor(rewriter, op->getLoc(), rhsElemTy, + rhsZpConst) + .value(); + } + // Legalization constructs may offer input shapes but expect output shapes // to be inferred, e.g. // func @forward(%arg0: !torch.vtensor<[14,19],f32>, @@ -1873,44 +1902,18 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { SmallVector matmulOutputShape( {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); - bool isInputElemTyQInt8 = false; Type inputElemTy{lhsElemTy}; - if (auto inputQTy = - dyn_cast(lhsElemTy)) { - if (inputQTy.getStorageTypeIntegralWidth() == 8) - isInputElemTyQInt8 = true; - inputElemTy = inputQTy.getStorageType(); - } - auto accElemTy = getDefaultAccType(rewriter, inputElemTy); auto mmOutputTy = RankedTensorType::get( makeShapeLLVMCompatible(matmulOutputShape), accElemTy); - Value mmOpResult; - if (!isInputElemTyQInt8) { - // LHS and RHS tensors' zero points must be zero for non-int8 types - Value lhsZp = - tosa::createZeroPointTensor(rewriter, op->getLoc(), lhsElemTy, 0) - .value(); - Value rhsZp = - tosa::createZeroPointTensor(rewriter, op->getLoc(), rhsElemTy, 0) - .value(); - mmOpResult = - tosa::MatMulOp::create( - rewriter, op->getLoc(), - OpConversionPattern::getTypeConverter()->convertType( - mmOutputTy), - matmulLhs, matmulRhs, lhsZp, rhsZp) - .getResult(); - } else { - mmOpResult = - tosa::MatMulOp::create( - rewriter, op->getLoc(), - OpConversionPattern::getTypeConverter()->convertType( - mmOutputTy), - matmulLhs, matmulRhs) - .getResult(); - } + Value mmOpResult = + tosa::MatMulOp::create( + rewriter, op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + mmOutputTy), + matmulLhs, matmulRhs, lhsZp, rhsZp) + .getResult(); // Perform the reshape to output shape. 
This is always required unless max // input rank=3 and there was no broadcasting, in which case the tosa.matmul @@ -2047,14 +2050,15 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value lhs, rhs; + Value lhs, rhs, lhsZp, rhsZp; - if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs))) + if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs, lhsZp, rhsZp))) return rewriter.notifyMatchFailure(op, "Failed to read matmul inputs"); Value output; - if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output))) + if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, lhsZp, rhsZp, + output))) return rewriter.notifyMatchFailure(op, "Failed to perform matmul operation"); @@ -2079,7 +2083,8 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &lhs, Value &rhs) const override { + Value &lhs, Value &rhs, Value &lhsZp, + Value &rhsZp) const override { lhs = adaptor.getSelf(); auto lhsTy = cast(lhs.getType()); @@ -2090,6 +2095,19 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA matmul"); + // Get values from the op itself instead of the adaptor (that returns + // converted values in 'tensor' type instead of 'vtensor') so that the + // connection with source torch ops carrying the quantization information + // is preserved and use-def analysis can be performed to extract such + // information. + getZeroPoint(op.getSelf(), lhsZp); + getZeroPoint(op.getOther(), rhsZp); + + if (static_cast(lhsZp) != static_cast(rhsZp)) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.matmul with mixed quantization"); + } + return success(); } }; @@ -2102,7 +2120,8 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &lhs, Value &rhs) const override { + Value &lhs, Value &rhs, Value &lhsZp, + Value &rhsZp) const override { lhs = adaptor.getSelf(); auto lhsTy = cast(lhs.getType()); @@ -2127,6 +2146,14 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { return op.emitError("aten.bmm called but matrix rank != 3"); } + getZeroPoint(op.getSelf(), lhsZp); + getZeroPoint(op.getMat2(), rhsZp); + + if (static_cast(lhsZp) != static_cast(rhsZp)) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.mm/aten.bmm with mixed quantization"); + } + return success(); } }; @@ -2139,7 +2166,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, - Value &lhs, Value &rhs) const override { + Value &lhs, Value &rhs, Value &lhsZp, + Value &rhsZp) const override { lhs = adaptor.getInput(); auto lhsTy = cast(lhs.getType()); @@ -2165,6 +2193,14 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { return rewriter.notifyMatchFailure( op, "aten.Linear needs statically shaped input"); + getZeroPoint(op.getInput(), lhsZp); + getZeroPoint(op.getWeight(), rhsZp); + + if (static_cast(lhsZp) != static_cast(rhsZp)) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.Linear with mixed quantization"); + 
+    }
+
     return success();
   }
 
   // Override the default rewriter to perform RHS transpose and bias addition as
@@ -2173,9 +2209,9 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    Value lhs, rhs;
+    Value lhs, rhs, lhsZp, rhsZp;
 
-    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs, lhsZp, rhsZp)))
       return rewriter.notifyMatchFailure(op, "Failed to read matmul inputs");
 
     // The aten.Linear op has a bias tensor that is added to the matmul output.
@@ -2215,8 +2251,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
           rhs, rewriter.getDenseI32ArrayAttr(transposedRhsDims));
 
     Value matmulOutput;
-    if (failed(
-            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
+    if (failed(this->performMatmul(op, adaptor, rewriter, lhs, rhs, lhsZp,
+                                   rhsZp, matmulOutput)))
       return rewriter.notifyMatchFailure(op,
                                          "Failed to perform matmul operation");
 
@@ -9026,6 +9062,107 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+template <typename OpTy>
+class ConvertCastEquivalentOp : public OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpAdaptor = typename OpTy::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto converter = this->getTypeConverter();
+    RankedTensorType resultType = cast<RankedTensorType>(
+        converter->convertType(op->getResult(0).getType()));
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType,
+                                              adaptor.getSelf());
+    return success();
+  }
+};
+
+// Legalization for aten.dequantize.tensor
+template <>
+LogicalResult ConvertAtenOp<AtenDequantizeTensorOp>::matchAndRewrite(
+    AtenDequantizeTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto loc = op->getLoc();
+  auto qtensor = adaptor.getQtensor();
+
+  // Find the zero_point and scale values from the op itself instead of the
+  // adaptor (which returns converted values in 'tensor' type instead of
+  // 'vtensor') so that the connection with the source torch ops carrying the
+  // quantization information is preserved and use-def analysis can be
+  // performed to extract such information.
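+
+  // The lowering below computes (value - zero_point) * scale: an integer
+  // subtract of the zero point followed by a floating-point multiply by the
+  // scale. For example (illustrative values only): an i8 storage value of 7
+  // with zero_point = -25 and scale = 2.15e-02 dequantizes to
+  // (7 - (-25)) * 2.15e-02 = 0.688.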
+  Value zp, scale;
+  auto value = op.getQtensor();
+  if (auto makeQTensor =
+          value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
+    zp = makeQTensor.getZeroPoint();
+    scale = makeQTensor.getScale();
+  } else if (auto quant = value.getDefiningOp<AtenQuantizePerTensorOp>()) {
+    zp = quant.getZeroPoint();
+    scale = quant.getScale();
+  }
+
+  if (!zp || !scale) {
+    return rewriter.notifyMatchFailure(op,
+                                       "could not find quantization params");
+  }
+
+  int64_t zpConst;
+  double scaleConst;
+  if (!matchPattern(zp, m_TorchConstantInt(&zpConst)))
+    return rewriter.notifyMatchFailure(op,
+                                       "zero point must be a Scalar constant");
+
+  if (!matchPattern(scale, m_TorchConstantFloat(&scaleConst)))
+    return rewriter.notifyMatchFailure(op, "scale must be a Scalar constant");
+
+  // Get result types
+  auto resultTensorTy = cast<RankedTensorType>(
+      getTypeConverter()->convertType(op->getResult(0).getType()));
+  auto elemFpTy = resultTensorTy.getElementType();
+  auto shape = resultTensorTy.getShape();
+
+  // Define the intermediate integer calculation type using the same bitwidth
+  // as the output float type
+  auto elemIntTy = rewriter.getIntegerType(elemFpTy.getIntOrFloatBitWidth());
+  auto intTensorTy = RankedTensorType::get(shape, elemIntTy);
+
+  // Cast the quantized input to the intermediate integer type;
+  // tosa.cast handles the i8 -> i32 extension (as needed)
+  Value intValue = tosa::CastOp::create(rewriter, loc, intTensorTy, qtensor);
+
+  // Subtract: (value - zero_point)
+  Value intZp =
+      tosa::createZeroPointTensor(rewriter, op->getLoc(), elemIntTy, zpConst)
+          .value();
+
+  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), intValue, intZp)
+          .failed())
+    return failure();
+
+  Value subResult =
+      tosa::SubOp::create(rewriter, loc, intTensorTy, intValue, intZp);
+
+  // Multiply: (value - zero_point) * scale
+  // Both operands are cast to the result type
+  Value floatResult =
+      tosa::CastOp::create(rewriter, loc, resultTensorTy, subResult);
+
+  auto scaleType = resultTensorTy.clone(elemFpTy);
+  auto scaleAttr = DenseElementsAttr::get(
+      scaleType, rewriter.getFloatAttr(elemFpTy, scaleConst));
+  auto floatScale =
+      tosa::ConstOp::create(rewriter, op->getLoc(), scaleType, scaleAttr);
+
+  auto mulResult = tosa::createMulOpAndCast(
+      rewriter, op, resultTensorTy, floatResult, floatScale, /*shift=*/0);
+
+  rewriter.replaceOp(op, mulResult);
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -9400,6 +9537,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_ATENOP_PATTERN(AtenExpm1Op);
   INSERT_ATENOP_PATTERN(AtenTanOp);
   INSERT_ATENOP_PATTERN(AtenUnfoldOp);
+  INSERT_ATENOP_PATTERN(AtenDequantizeTensorOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
@@ -9408,6 +9546,14 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
   INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
 #undef INSERT_CLONE_ATENOP_PATTERN
 
+#define INSERT_CAST_ATENOP_PATTERN(AtenOp)                                     \
+  illegalOps.insert(AtenOp::getOperationName());                               \
+  patterns.add<ConvertCastEquivalentOp<AtenOp>>(typeConverter, context);
+  INSERT_CAST_ATENOP_PATTERN(Aten_MakePerChannelQuantizedTensorOp);
+  INSERT_CAST_ATENOP_PATTERN(Aten_MakePerTensorQuantizedTensorOp);
+  INSERT_CAST_ATENOP_PATTERN(AtenIntReprOp);
+#undef INSERT_CAST_ATENOP_PATTERN
+
   return illegalOps;
 }
 
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index 0bff646023cf..820f82afea5a 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -565,6 +565,12 @@ FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
   return
squeezed; } +void getZeroPoint(Value value, Value &zeropoint) { + if (auto make = value.getDefiningOp()) { + zeropoint = make.getZeroPoint(); + } +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0e1b3b67d102..5fc42cc1db1d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3642,12 +3642,7 @@ "AtenItemIntOpModule_basic", "AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8_basic", - "AtenMatmulQint8MV_basic", - "AtenMatmulQint8VM_basic", - "AtenMatmulQint8VV_basic", - "AtenMatmulQint8_basic", "AtenMmQMixedSigni8_basic", - "AtenMmQint8_basic", "AtenMmQuint8_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 12f971cc9767..06dd7aa15cc9 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -4260,21 +4260,23 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %false, %3, %int1 : !torch.vtensor<[2,2,6,6],si8>, !torch.vtensor<[8,2,3,3],si8>, !torch.vtensor<[8],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[2,8,4,4],si32> return %4 : !torch.vtensor<[2,8,4,4],si32> } + +// ----- // CHECK-LABEL: func.func @torch.aten.mm$f32( -// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[1,22],f32>, -// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> { +// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[1,22],f32>, +// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> { // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[22,10],f32> -> tensor<22x10xf32> // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[1,22],f32> -> tensor<1x22xf32> -// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 22]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<1x22xf32>, !tosa.shape<3>) -> tensor<1x1x22xf32> -// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 22, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<22x10xf32>, !tosa.shape<3>) -> tensor<1x22x10xf32> // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x1x22xf32>, tensor<1x22x10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x10xf32> +// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 22]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<1x22xf32>, !tosa.shape<3>) -> tensor<1x1x22xf32> +// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 22, 10]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<22x10xf32>, !tosa.shape<3>) -> tensor<1x22x10xf32> +// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPED]], 
%[[WTS_RESHAPED]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x1x22xf32>, tensor<1x22x10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x10xf32> // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 10]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x1x10xf32>, !tosa.shape<2>) -> tensor<1x10xf32> -// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<1x10xf32> -> !torch.vtensor<[1,10],f32> +// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x1x10xf32>, !tosa.shape<2>) -> tensor<1x10xf32> +// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPED]] : tensor<1x10xf32> -> !torch.vtensor<[1,10],f32> // CHECK: return %[[RES]] func.func @torch.aten.mm$f32(%arg0: !torch.vtensor<[1,22],f32>, %arg1: !torch.vtensor<[22,10],f32>) -> !torch.vtensor<[1,10],f32> { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[1,22],f32>, !torch.vtensor<[22,10],f32> -> !torch.vtensor<[1,10],f32> @@ -4319,23 +4321,23 @@ func.func @torch.aten.mm$bf16(%arg0: !torch.vtensor<[1,22],bf16>, %arg1: !torch. // ----- // CHECK-LABEL: func.func @torch.aten.matmul$broadcast( -// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[10,3,4],f32>, -// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> { +// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[10,3,4],f32>, +// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> { // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[4],f32> -> tensor<4xf32> // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> -// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[WTS_SHAPE]] : (tensor<4xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> -// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 30, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<10x3x4xf32>, !tosa.shape<3>) -> tensor<1x30x4xf32> -// CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_RESHAPE]] {perms = array} : (tensor<1x4x1xf32>) -> tensor<4x1x1xf32> -// CHECK: %[[WTS_SHAPE_2:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[WTS_RESHAPE_2:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE_2]] : (tensor<4x1x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE_2]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x30x4xf32>, tensor<1x4x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x30x1xf32> +// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[WTS_TENSOR]], %[[INP_SHAPE]] : (tensor<4xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> +// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 30, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[INP_TENSOR]], %[[WTS_SHAPE]] : (tensor<10x3x4xf32>, !tosa.shape<3>) -> tensor<1x30x4xf32> 
+// CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[INP_RESHAPED]] {perms = array} : (tensor<1x4x1xf32>) -> tensor<4x1x1xf32> +// CHECK: %[[WTS_SHAPE_2:.*]] = tosa.const_shape {values = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[WTS_RESHAPED_2:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE_2]] : (tensor<4x1x1xf32>, !tosa.shape<3>) -> tensor<1x4x1xf32> +// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[WTS_RESHAPED]], %[[WTS_RESHAPED_2]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x30x4xf32>, tensor<1x4x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x30x1xf32> // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[10, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x30x1xf32>, !tosa.shape<2>) -> tensor<10x3xf32> -// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPE]] : tensor<10x3xf32> -> !torch.vtensor<[10,3],f32> +// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x30x1xf32>, !tosa.shape<2>) -> tensor<10x3xf32> +// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[RES_RESHAPED]] : tensor<10x3xf32> -> !torch.vtensor<[10,3],f32> // CHECK: return %[[RES]] func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[10,3],f32> { %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[10,3],f32> @@ -4351,19 +4353,19 @@ func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[3,4],f16> -> tensor<3x4xf16> // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[2,4],f16> -> tensor<2x4xf16> // CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_TENSOR]] {perms = array} : (tensor<3x4xf16>) -> tensor<4x3xf16> -// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16> -// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16> // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16> -// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x2x3xf32> +// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16> +// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16> +// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, 
tensor<1xf16>) -> tensor<1x2x3xf32> // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32> -// CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPE]] : (tensor<2x3xf32>) -> tensor<2x3xf16> +// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32> +// CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPED]] : (tensor<2x3xf32>) -> tensor<2x3xf16> // CHECK: %[[BIAS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK: %[[BIAS_RESHAPE:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16> -// CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPE]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16> +// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16> +// CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPED]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16> // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x3xf16> -> !torch.vtensor<[2,3],f16> // CHECK: return %[[RES]] func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch.vtensor<[3,4],f16>, %arg2: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> { diff --git a/test/Conversion/TorchToTosa/quantization.mlir b/test/Conversion/TorchToTosa/quantization.mlir new file mode 100644 index 000000000000..74eef9a496d1 --- /dev/null +++ b/test/Conversion/TorchToTosa/quantization.mlir @@ -0,0 +1,44 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa --canonicalize -split-input-file | FileCheck %s +// COM: --canonicalize is used to clean up the IR after conversion to make resulting IR easier to read + + +// CHECK-LABEL: func.func @AtenMmQint8( +// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[3,4],si8>, +// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[4,3],si8>) -> !torch.vtensor<[3,3],f32> { +// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[OUT_SCALE:.*]] = "tosa.const"() <{values = dense<3.784000e-04> : tensor<3x3xf32>}> : () -> tensor<3x3xf32> +// CHECK-DAG: %[[MUL_OUT_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[RHS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[LHS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[RHS_ZP:.*]] = "tosa.const"() <{values = dense<18> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[LHS_ZP:.*]] = "tosa.const"() <{values = dense<-25> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[RHS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[RHS]] : !torch.vtensor<[4,3],si8> -> tensor<4x3xi8> +// CHECK: %[[LHS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[LHS]] : !torch.vtensor<[3,4],si8> -> tensor<3x4xi8> +// CHECK: %[[LHS_RESHAPED:.*]] = tosa.reshape %[[LHS_TENSOR]], %[[LHS_SHAPE]] : (tensor<3x4xi8>, !tosa.shape<3>) -> tensor<1x3x4xi8> +// CHECK: %[[RHS_RESHAPED:.*]] = tosa.reshape %[[RHS_TENSOR]], %[[RHS_SHAPE]] : (tensor<4x3xi8>, !tosa.shape<3>) -> tensor<1x4x3xi8> +// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[LHS_RESHAPED]], %[[RHS_RESHAPED]], %[[LHS_ZP]], %[[RHS_ZP]] : 
(tensor<1x3x4xi8>, tensor<1x4x3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x3x3xi32> +// CHECK: %[[MATMUL_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[MUL_OUT_SHAPE]] : (tensor<1x3x3xi32>, !tosa.shape<2>) -> tensor<3x3xi32> +// CHECK: %[[MATMUL_FP32:.*]] = tosa.cast %[[MATMUL_RESHAPE]] : (tensor<3x3xi32>) -> tensor<3x3xf32> +// CHECK: %[[OUT_SCALED:.*]] = tosa.mul %[[MATMUL_FP32]], %[[OUT_SCALE]], %[[SHIFT]] : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi8>) -> tensor<3x3xf32> +// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[OUT_SCALED]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> +// CHECK: return %[[RES]] +func.func @AtenMmQint8(%arg0: !torch.vtensor<[3,4],si8>, %arg1: !torch.vtensor<[4,3],si8>) -> !torch.vtensor<[3,3],f32> +{ + %float3.784000e-04 = torch.constant.float 3.784000e-04 + %int0 = torch.constant.int 0 + %int18 = torch.constant.int 18 + %float1.760000e-02 = torch.constant.float 1.760000e-02 + %float2.150000e-02 = torch.constant.float 2.150000e-02 + %int-25 = torch.constant.int -25 + %int-128 = torch.constant.int -128 + %int127 = torch.constant.int 127 + %0 = torch.aten.clamp %arg0, %int-128, %int127 : !torch.vtensor<[3,4],si8>, !torch.int, !torch.int -> !torch.vtensor<[3,4],si8> + %1 = torch.aten.clamp %arg1, %int-128, %int127 : !torch.vtensor<[4,3],si8>, !torch.int, !torch.int -> !torch.vtensor<[4,3],si8> + %2 = torch.aten._make_per_tensor_quantized_tensor %0, %float2.150000e-02, %int-25 : !torch.vtensor<[3,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,4],!torch.qint8> + %3 = torch.aten._make_per_tensor_quantized_tensor %1, %float1.760000e-02, %int18 : !torch.vtensor<[4,3],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.qint8> + %4 = torch.aten.mm %2, %3 : !torch.vtensor<[3,4],!torch.qint8>, !torch.vtensor<[4,3],!torch.qint8> -> !torch.vtensor<[3,3],!torch.qint32> + %5 = torch.aten.int_repr %4 : !torch.vtensor<[3,3],!torch.qint32> -> !torch.vtensor<[3,3],si32> + %6 = torch.aten._make_per_tensor_quantized_tensor %5, %float3.784000e-04, %int0 : !torch.vtensor<[3,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[3,3],!torch.qint32> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[3,3],!torch.qint32> -> !torch.vtensor<[3,3],f32> + return %7 : !torch.vtensor<[3,3],f32> +}
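
For reference, the quantization constants in the AtenMmQint8 test above are
consistent with each other: the final dequantize scale 3.784000e-04 is the
product of the two operand scales (2.150000e-02 * 1.760000e-02), so rescaling
the i32 accumulator once at the end matches dequantizing each operand first.
A minimal NumPy sketch of that reference computation (not part of the patch;
the operand values below are made up for illustration):

import numpy as np

# Quantization parameters taken from the AtenMmQint8 test case above.
lhs_scale, lhs_zp = 2.150000e-02, -25
rhs_scale, rhs_zp = 1.760000e-02, 18
out_scale = 3.784000e-04  # == lhs_scale * rhs_scale

# Arbitrary int8 storage values standing in for the quantized operands.
rng = np.random.default_rng(0)
lhs_q = rng.integers(-128, 128, size=(3, 4), dtype=np.int8)
rhs_q = rng.integers(-128, 128, size=(4, 3), dtype=np.int8)

# What tosa.matmul with zero-point operands computes: subtract each operand's
# zero point and accumulate the products in i32.
acc = (lhs_q.astype(np.int32) - lhs_zp) @ (rhs_q.astype(np.int32) - rhs_zp)

# aten.dequantize.tensor on the i32 result: (value - zero_point) * scale,
# with zero_point == 0 in the test.
lowered = acc.astype(np.float32) * out_scale

# Reference result: dequantize each operand to float first, then matmul.
lhs_f = (lhs_q.astype(np.float32) - lhs_zp) * lhs_scale
rhs_f = (rhs_q.astype(np.float32) - rhs_zp) * rhs_scale
assert np.allclose(lowered, lhs_f @ rhs_f, atol=1e-4)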