
Commit 94d96a1

Arm backend: Add test for BatchNorm1D QAT folding (#16001)
Adds new tests for Conv1d+BatchNorm1d QAT folding and makes sure that in-place hardtanh can be fused with convolution/linear layers as well.

Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent a93f59e commit 94d96a1

2 files changed: +87, −22 lines
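For context on what the tests exercise: during QAT prepare, each Conv+BatchNorm pair is folded so the conv weights absorb the batch-norm statistics before fake-quant observers are inserted; the new tests verify this also happens for Conv1d+BatchNorm1d. Below is a minimal sketch of the underlying fold for the 1D case, using a hypothetical standalone helper (not code from this commit) and assuming eval-mode running statistics:

import torch
from torch import nn

def fold_bn_into_conv1d(conv: nn.Conv1d, bn: nn.BatchNorm1d) -> nn.Conv1d:
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded = nn.Conv1d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size[0],
        stride=conv.stride,
        padding=conv.padding,
        bias=True,
    )
    with torch.no_grad():
        # W' = W * scale, broadcast over the output-channel axis
        folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1))
        # b' = (b - running_mean) * scale + beta
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        folded.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return folded

The Conv2d case is identical except that the scale broadcasts as reshape(-1, 1, 1, 1) over the 4D weight.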

backends/arm/quantizer/quantization_annotator.py

Lines changed: 17 additions & 5 deletions
@@ -16,7 +16,6 @@
 
 import torch
 import torch.fx
-import torch.nn.functional as F
 from executorch.backends.arm.common.debug import get_node_debug_info
 from executorch.backends.arm.common.type import ensure_type
 from executorch.backends.arm.quantizer import QuantizationConfig
@@ -477,7 +476,11 @@ def get_quant_properties( # noqa: C901
     def any_or_hardtanh_min_zero(n: Node):
         """Return True for any op or hardtanh with ``min_val == 0``."""
         # Check that if the node is a hardtanh, its min_val is zero
-        return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
+        return (
+            n.target
+            not in (torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default)
+            or n.args[1] == 0
+        )
 
     if _match_pattern(
         node,
@@ -487,11 +490,14 @@ def any_or_hardtanh_min_zero(n: Node):
             torch.ops.aten.conv2d.default,
             torch.ops.aten.conv2d.padding,
         ],
-        [torch.ops.aten.batch_norm.default, F.batch_norm],
+        [
+            torch.ops.aten.batch_norm.default,
+        ],
         [
             torch.ops.aten.relu.default,
             torch.ops.aten.relu_.default,
             torch.ops.aten.hardtanh.default,
+            torch.ops.aten.hardtanh_.default,
         ],
     ],
     filter_fn=any_or_hardtanh_min_zero,
@@ -510,6 +516,7 @@ def any_or_hardtanh_min_zero(n: Node):
             torch.ops.aten.relu.default,
             torch.ops.aten.relu_.default,
             torch.ops.aten.hardtanh.default,
+            torch.ops.aten.hardtanh_.default,
         ):
             quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
 
@@ -521,7 +528,9 @@ def any_or_hardtanh_min_zero(n: Node):
             torch.ops.aten.conv2d.default,
             torch.ops.aten.conv2d.padding,
         ],
-        [torch.ops.aten.batch_norm.default, F.batch_norm],
+        [
+            torch.ops.aten.batch_norm.default,
+        ],
     ],
 ):
     if node.target in (
@@ -534,7 +543,9 @@ def any_or_hardtanh_min_zero(n: Node):
             _QuantProperty(1, weight_qspec, mark_annotated=True),
             _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
         ]
-    elif node.target in [torch.ops.aten.batch_norm.default, F.batch_norm]:
+    elif node.target in [
+        torch.ops.aten.batch_norm.default,
+    ]:
         quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     elif _match_pattern(
         node,
@@ -549,6 +560,7 @@ def any_or_hardtanh_min_zero(n: Node):
             torch.ops.aten.relu.default,
             torch.ops.aten.relu_.default,
             torch.ops.aten.hardtanh.default,
+            torch.ops.aten.hardtanh_.default,
         ],
     ],
     any_or_hardtanh_min_zero,
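Why the in-place overload needs its own entries: nn.ReLU(inplace=True) and nn.Hardtanh(inplace=True) are captured as torch.ops.aten.relu_.default and torch.ops.aten.hardtanh_.default in the pre-dispatch graph the quantizer annotates, so pattern lists containing only the out-of-place overloads would silently skip those modules. A small sketch of how one might observe this (assuming a recent PyTorch with torch.export.export_for_training; not code from this commit):

import torch
from torch import nn

class ConvHardtanh(nn.Module):
    """Hypothetical module illustrating the pattern the annotator now matches."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3)
        # min_val=0.0 is what any_or_hardtanh_min_zero checks via n.args[1] == 0
        self.act = nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True)

    def forward(self, x):
        return self.act(self.conv(x))

# Pre-dispatch export (the graph PT2E quantizers annotate) should keep the
# in-place overload, so the node target is aten.hardtanh_.default rather
# than aten.hardtanh.default.
ep = torch.export.export_for_training(ConvHardtanh(), (torch.randn(1, 1, 8, 8),))
print([n.target for n in ep.graph.nodes if n.op == "call_function"])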

backends/arm/test/misc/test_bn_relu_folding_qat.py

Lines changed: 70 additions & 17 deletions
@@ -6,13 +6,13 @@
 from typing import Tuple
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_quantization_config,
     TOSAQuantizer,
 )
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT
+from executorch.backends.arm.tosa import TosaSpecification
 
 from executorch.backends.xnnpack.test.tester.tester import Quantize
 from torch import nn
@@ -21,51 +21,104 @@
 input_t1 = Tuple[torch.Tensor]  # Input x
 
 
-class ConvModule(torch.nn.Module):
+class Conv2dModule(torch.nn.Module):
     input_shape = (1, 28, 28)
     batch_size = 64
     test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
 
-    def __init__(self, batch_norm: bool = True) -> None:
+    def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None:
         super().__init__()
         self.conv = torch.nn.Conv2d(1, 16, 3, stride=2)
         self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity()
+        self.relu = nn.ReLU(inplace=inplace)
 
     def forward(self, x: torch.Tensor):
         x = self.conv(x)
         x = self.bn(x)
-        x = F.relu(x)
+        x = self.relu(x)
+
+        return x
+
+
+class Conv1dModule(torch.nn.Module):
+    input_shape = (3, 10)
+    batch_size = 2
+    test_data: input_t1 = (torch.randn(batch_size, *input_shape),)
+
+    def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None:
+        super().__init__()
+        self.conv = torch.nn.Conv1d(3, 8, 5, padding=2)
+        self.bn = nn.BatchNorm1d(num_features=8) if batch_norm else nn.Identity()
+        self.relu = nn.ReLU(inplace=inplace)
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
 
         return x
 
 
 models = {
     # name : (model, is_per_channel)
-    "conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True),
-    "conv_relu_per_channel": (ConvModule(batch_norm=False), True),
-    "conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False),
-    "conv_relu_per_tensor": (ConvModule(batch_norm=False), False),
+    "conv1d_bn_relu_per_channel": (Conv1dModule(batch_norm=True), True),
+    "conv1d_relu_per_channel": (Conv1dModule(batch_norm=False), True),
+    "conv1d_bn_relu_per_tensor": (Conv1dModule(batch_norm=True), False),
+    "conv1d_relu_per_tensor": (Conv1dModule(batch_norm=False), False),
+    "conv2d_bn_relu_per_channel": (Conv2dModule(batch_norm=True), True),
+    "conv2d_relu_per_channel": (Conv2dModule(batch_norm=False), True),
+    "conv2d_bn_relu_per_tensor": (Conv2dModule(batch_norm=True), False),
+    "conv2d_relu_per_tensor": (Conv2dModule(batch_norm=False), False),
+    "conv1d_bn_relu_inplace_per_channel": (
+        Conv1dModule(batch_norm=True, inplace=True),
+        True,
+    ),
+    "conv1d_relu_inplace_per_channel": (
+        Conv1dModule(batch_norm=False, inplace=True),
+        True,
+    ),
+    "conv1d_bn_relu_inplace_per_tensor": (
+        Conv1dModule(batch_norm=True, inplace=True),
+        False,
+    ),
+    "conv1d_relu_inplace_per_tensor": (
+        Conv1dModule(batch_norm=False, inplace=True),
+        False,
+    ),
+    "conv2d_bn_relu_inplace_per_channel": (
+        Conv2dModule(batch_norm=True, inplace=True),
+        True,
+    ),
+    "conv2d_relu_inplace_per_channel": (
+        Conv2dModule(batch_norm=False, inplace=True),
+        True,
+    ),
+    "conv2d_bn_relu_inplace_per_tensor": (
+        Conv2dModule(batch_norm=True, inplace=True),
+        False,
+    ),
+    "conv2d_relu_inplace_per_tensor": (
+        Conv2dModule(batch_norm=False, inplace=True),
+        False,
+    ),
 }
 
 
-@common.parametrize("test_data", models)
+@common.parametrize(
+    "test_data",
+    models,
+)
 def test_qat_tosa_INT(test_data):
     model, per_channel = test_data
     pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1)
-    tosa_version = conftest.get_option("tosa_version")
-    tosa_profiles = {
-        "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"),
-    }
-    tosa_spec = tosa_profiles[tosa_version]
-    quantizer = TOSAQuantizer(tosa_spec)
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
     pipeline.change_args(
         "quantize",
         Quantize(
             quantizer=quantizer,
             quantization_config=get_symmetric_quantization_config(
                 is_qat=True, is_per_channel=per_channel
             ),
-            is_qat=True,
         ),
     )
     pipeline.run()
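To try the new cases locally, a typical pytest invocation should work (assuming a configured ExecuTorch/Arm test environment); the -k filter matches the parametrized test ids defined in the models dict above:

python -m pytest backends/arm/test/misc/test_bn_relu_folding_qat.py -k "inplace" -v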
