
Commit b834f94

[tosa] : Add e2e support for quantized matmul. (#4371)
This PR enables e2e tests for quantized torch.mm and its other variants through the `tosa` path. The `torch` IR for a quantized matmul is shown in the following snippet:

```
%2 = torch.aten._make_per_tensor_quantized_tensor %0, %float2.150000e-02, %int-25 : !torch.vtensor<[3,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,4],!torch.qint8>
%3 = torch.aten._make_per_tensor_quantized_tensor %1, %float1.760000e-02, %int18 : !torch.vtensor<[4,3],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.qint8>
%4 = torch.aten.mm %2, %3 : !torch.vtensor<[3,4],!torch.qint8>, !torch.vtensor<[4,3],!torch.qint8> -> !torch.vtensor<[3,3],!torch.qint32>
%5 = torch.aten.int_repr %4 : !torch.vtensor<[3,3],!torch.qint32> -> !torch.vtensor<[3,3],si32>
%6 = torch.aten._make_per_tensor_quantized_tensor %5, %float3.784000e-04, %int0 : !torch.vtensor<[3,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[3,3],!torch.qint32>
%7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[3,3],!torch.qint32> -> !torch.vtensor<[3,3],f32>
```

1. This change adds legalizations for `_make_per_tensor_quantized_tensor` and `int_repr`, which are essentially cast operations. The former op carries the zero-point/scale information needed for (de)quantizing values.
2. A legalization for `dequantize.tensor`, the usual dequantization op, is also added.
3. The legalization for `matmul` is fixed to infer the zero-point information from the source `_make_per_tensor_quantized_tensor` ops feeding the `matmul` operands. The scale does not need to be considered, as it is handled correctly at the output by the `FuseQuantizedOps` transform; see the arithmetic sketch just below this list.
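As a sanity check on the constants in the snippet above, here is a minimal plain-C++ sketch of the arithmetic model behind a per-tensor quantized matmul. It is illustrative only, not the actual TOSA lowering, and the `lhs`/`rhs` input values are made up; the scales and zero-points are taken from the IR. It shows why the matmul legalization only needs the operands' zero-points: the output scale is simply the product of the input scales (0.0215 * 0.0176 = 3.784e-4, matching `%float3.784000e-04` above).

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Scales and zero-points from the IR snippet above.
  const float lhsScale = 2.15e-2f, rhsScale = 1.76e-2f;
  const int32_t lhsZp = -25, rhsZp = 18;

  // One illustrative accumulator entry for arbitrary i8 inputs
  // (a single row-by-column dot product of the 3x4 @ 4x3 matmul).
  const int8_t lhs[4] = {12, -7, 33, 5};
  const int8_t rhs[4] = {-2, 19, 8, -11};
  int32_t acc = 0;
  for (int k = 0; k < 4; ++k)
    acc += (static_cast<int32_t>(lhs[k]) - lhsZp) *
           (static_cast<int32_t>(rhs[k]) - rhsZp);

  // Dequantize the i32 accumulator: the output scale is the product of
  // the input scales, and the output zero-point is 0, as in the IR.
  const float outScale = lhsScale * rhsScale; // 3.784e-4
  std::printf("dequantized: %f\n", outScale * static_cast<float>(acc));
  return 0;
}
```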
1 parent 06f72e4 commit b834f94

File tree

7 files changed: +284 −92 lines changed


include/torch-mlir/Conversion/Utils/Utils.h

Lines changed: 11 additions & 6 deletions

```diff
@@ -18,6 +18,16 @@ namespace mlir {
 namespace torch {
 namespace Torch {
 
+// Define constants
+// Float 16 limits
+constexpr float Float16Max = 65504.0f;
+constexpr float Float16Lowest = -65504.0f;
+
+// BFloat 16 limits
+constexpr float BFloat16Max = 3.38953139e38f;
+constexpr float BFloat16Lowest = -3.38953139e38f;
+
+// Define utility methods
 LogicalResult verifyLinalgCompatibleTypes(Operation *op,
                                           PatternRewriter &rewriter);
 
@@ -107,13 +117,8 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
 FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
                                Value input, int64_t dim);
 
-// Float 16 limits
-constexpr float Float16Max = 65504.0f;
-constexpr float Float16Lowest = -65504.0f;
+void getZeroPoint(Value value, Value &zeropoint);
 
-// BFloat 16 limits
-constexpr float BFloat16Max = 3.38953139e38f;
-constexpr float BFloat16Lowest = -3.38953139e38f;
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
```
lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 0 additions & 6 deletions

```diff
@@ -28,12 +28,6 @@ using namespace mlir::torch::Torch;
 
 namespace {
 
-static void getZeroPoint(Value value, Value &zeropoint) {
-  if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
-    zeropoint = make.getZeroPoint();
-  }
-}
-
 // for uint8 types, we shift down by 128 so that we can faithfully
 // represent the quantization with signed i8 types.
 static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
```
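The `signShift` comment in the context above captures the key invariant: re-centering a uint8 value and its zero-point by 128 leaves the dequantized result unchanged, because dequantization only scales the difference `value - zeroPoint`. A minimal standalone sketch of that arithmetic (illustrative values, not the actual rewriter code):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  uint8_t value = 200;    // stored uint8 quantized value
  int32_t zeroPoint = 25; // its uint8 zero-point

  // Shift both the value and the zero-point down by 128 so the data
  // fits in signed i8 (200 -> 72, 25 -> -103).
  int8_t shiftedValue =
      static_cast<int8_t>(static_cast<int32_t>(value) - 128);
  int32_t shiftedZeroPoint = zeroPoint - 128;

  // The quantity that dequantization scales is unchanged:
  // (200 - 25) == (72 - (-103)) == 175.
  assert(static_cast<int32_t>(value) - zeroPoint ==
         static_cast<int32_t>(shiftedValue) - shiftedZeroPoint);
  return 0;
}
```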
