Commit 06f72e4

[TOSA] Lower boolean aten.bitwise_not to logical_not (#4364)
- Fix TorchToTosa's shared unary pattern so AtenBitwiseNotOp with i1 outputs emits tosa.logical_not instead of tosa.bitwise_not.
- Add a regression test in test/Conversion/TorchToTosa/basic.mlir that checks the lowering path for a bool tensor.
- Add an end-to-end regression test for bitwise_not on a boolean tensor.
1 parent a2bcca0 commit 06f72e4
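
For context, PyTorch already treats bitwise_not on bool tensors as logical negation, which is why lowering the i1 case to tosa.logical_not is semantically equivalent. A minimal PyTorch sketch, independent of torch-mlir (the tensor values are illustrative):

import torch

# On a boolean tensor, bitwise NOT and logical NOT agree element-wise.
x = torch.tensor([[True, False, True], [False, False, True]])

bitwise = torch.bitwise_not(x)   # what aten.bitwise_not computes for bool
logical = torch.logical_not(x)   # what tosa.logical_not models for i1

assert torch.equal(bitwise, logical)  # every element is simply flipped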

File tree

3 files changed: +48 -0 lines


lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,7 @@
 #include <numeric>
 #include <optional>
 #include <random>
+#include <type_traits>
 
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 
@@ -97,6 +98,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
 
     self = tosa::tosaCastTensorToType(rewriter, self, outType).value();
 
+    if constexpr (std::is_same_v<AtenOpT, AtenBitwiseNotOp>) {
+      if (auto intTy = dyn_cast<IntegerType>(outType.getElementType())) {
+        if (intTy.getWidth() == 1) {
+          rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(op, outType, self);
+          return success();
+        }
+      }
+      // otherwise fall through to standard emission
+    }
+
     rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, self);
 
     return success();
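
For readers skimming the diff, the added branch only changes which TOSA op the shared unary pattern emits when the result element type is i1. A rough Python analogue of that dispatch decision (names here are illustrative only; the actual logic is the C++ above):

# Illustrative sketch, not torch-mlir API: model the op-selection decision.
def pick_tosa_unary_op(aten_op: str, elem_type: str, default_tosa_op: str) -> str:
    """Return the TOSA op a unary aten op lowers to for a given element type."""
    if aten_op == "aten.bitwise_not" and elem_type == "i1":
        # Boolean tensors use the logical form instead of the bitwise one
        # (the fix in this commit).
        return "tosa.logical_not"
    # Otherwise fall through to the pattern's standard emission.
    return default_tosa_op

assert pick_tosa_unary_op("aten.bitwise_not", "i1", "tosa.bitwise_not") == "tosa.logical_not"
assert pick_tosa_unary_op("aten.bitwise_not", "i32", "tosa.bitwise_not") == "tosa.bitwise_not"
assert pick_tosa_unary_op("aten.ceil", "f32", "tosa.ceil") == "tosa.ceil"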

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 23 additions & 0 deletions
@@ -5175,6 +5175,29 @@ def ElementwiseBitwiseNotInt32Module_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseBitwiseNotBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.bool, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.bitwise_not(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseBitwiseNotBoolModule())
+def ElementwiseBitwiseNotBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, low=0, high=2).to(torch.bool))
+
+
+# ==============================================================================
+
+
 class ElementwiseSubTensorInt8Module(torch.nn.Module):
     def __init__(self):
         super().__init__()

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 14 additions & 0 deletions
@@ -135,6 +135,20 @@ func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to
 
 // -----
 
+// CHECK-LABEL: func.func @torch.aten.bitwise_not$bool(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> {
+// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[2,3],i1> -> tensor<2x3xi1>
+// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.logical_not %[[ARG_BUILTIN]] : (tensor<2x3xi1>) -> tensor<2x3xi1>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<2x3xi1> -> !torch.vtensor<[2,3],i1>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[2,3],i1>
+// CHECK: }
+func.func @torch.aten.bitwise_not$bool(%arg0: !torch.vtensor<[2,3],i1>) -> !torch.vtensor<[2,3],i1> {
+  %0 = torch.aten.bitwise_not %arg0 : !torch.vtensor<[2,3],i1> -> !torch.vtensor<[2,3],i1>
+  return %0 : !torch.vtensor<[2,3],i1>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.ceil$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>

0 commit comments