|
6 | 6 | from typing import Tuple |
7 | 7 |
|
8 | 8 | import torch |
9 | | -import torch.nn.functional as F |
10 | 9 | from executorch.backends.arm.quantizer.arm_quantizer import ( |
11 | 10 | get_symmetric_quantization_config, |
12 | 11 | TOSAQuantizer, |
13 | 12 | ) |
14 | | -from executorch.backends.arm.test import common, conftest |
| 13 | +from executorch.backends.arm.test import common |
15 | 14 | from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT |
| 15 | +from executorch.backends.arm.tosa import TosaSpecification |
16 | 16 |
|
17 | 17 | from executorch.backends.xnnpack.test.tester.tester import Quantize |
18 | 18 | from torch import nn |
|
21 | 21 | input_t1 = Tuple[torch.Tensor] # Input x |
22 | 22 |
|
23 | 23 |
|
24 | | -class ConvModule(torch.nn.Module): |
| 24 | +class Conv2dModule(torch.nn.Module): |
25 | 25 | input_shape = (1, 28, 28) |
26 | 26 | batch_size = 64 |
27 | 27 | test_data: input_t1 = (torch.randn(batch_size, *input_shape),) |
28 | 28 |
|
29 | | - def __init__(self, batch_norm: bool = True) -> None: |
| 29 | + def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None: |
30 | 30 | super().__init__() |
31 | 31 | self.conv = torch.nn.Conv2d(1, 16, 3, stride=2) |
32 | 32 | self.bn = nn.BatchNorm2d(num_features=16) if batch_norm else nn.Identity() |
| 33 | + self.relu = nn.ReLU(inplace=inplace) |
33 | 34 |
|
34 | 35 | def forward(self, x: torch.Tensor): |
35 | 36 | x = self.conv(x) |
36 | 37 | x = self.bn(x) |
37 | | - x = F.relu(x) |
| 38 | + x = self.relu(x) |
| 39 | + |
| 40 | + return x |
| 41 | + |
| 42 | + |
| 43 | +class Conv1dModule(torch.nn.Module): |
| 44 | + input_shape = (3, 10) |
| 45 | + batch_size = 2 |
| 46 | + test_data: input_t1 = (torch.randn(batch_size, *input_shape),) |
| 47 | + |
| 48 | + def __init__(self, batch_norm: bool = True, inplace: bool = False) -> None: |
| 49 | + super().__init__() |
| 50 | + self.conv = torch.nn.Conv1d(3, 8, 5, padding=2) |
| 51 | + self.bn = nn.BatchNorm1d(num_features=8) if batch_norm else nn.Identity() |
| 52 | + self.relu = nn.ReLU(inplace=inplace) |
| 53 | + |
| 54 | + def forward(self, x: torch.Tensor): |
| 55 | + x = self.conv(x) |
| 56 | + x = self.bn(x) |
| 57 | + x = self.relu(x) |
38 | 58 |
|
39 | 59 | return x |
40 | 60 |
|
41 | 61 |
|
42 | 62 | models = { |
43 | 63 | # name : (model, is_per_channel) |
44 | | - "conv_bn_relu_per_channel": (ConvModule(batch_norm=True), True), |
45 | | - "conv_relu_per_channel": (ConvModule(batch_norm=False), True), |
46 | | - "conv_bn_relu_per_tensor": (ConvModule(batch_norm=True), False), |
47 | | - "conv_relu_per_tensor": (ConvModule(batch_norm=False), False), |
| 64 | + "conv1d_bn_relu_per_channel": (Conv1dModule(batch_norm=True), True), |
| 65 | + "conv1d_relu_per_channel": (Conv1dModule(batch_norm=False), True), |
| 66 | + "conv1d_bn_relu_per_tensor": (Conv1dModule(batch_norm=True), False), |
| 67 | + "conv1d_relu_per_tensor": (Conv1dModule(batch_norm=False), False), |
| 68 | + "conv2d_bn_relu_per_channel": (Conv2dModule(batch_norm=True), True), |
| 69 | + "conv2d_relu_per_channel": (Conv2dModule(batch_norm=False), True), |
| 70 | + "conv2d_bn_relu_per_tensor": (Conv2dModule(batch_norm=True), False), |
| 71 | + "conv2d_relu_per_tensor": (Conv2dModule(batch_norm=False), False), |
| 72 | + "conv1d_bn_relu_inplace_per_channel": ( |
| 73 | + Conv1dModule(batch_norm=True, inplace=True), |
| 74 | + True, |
| 75 | + ), |
| 76 | + "conv1d_relu_inplace_per_channel": ( |
| 77 | + Conv1dModule(batch_norm=False, inplace=True), |
| 78 | + True, |
| 79 | + ), |
| 80 | + "conv1d_bn_relu_inplace_per_tensor": ( |
| 81 | + Conv1dModule(batch_norm=True, inplace=True), |
| 82 | + False, |
| 83 | + ), |
| 84 | + "conv1d_relu_inplace_per_tensor": ( |
| 85 | + Conv1dModule(batch_norm=False, inplace=True), |
| 86 | + False, |
| 87 | + ), |
| 88 | + "conv2d_bn_relu_inplace_per_channel": ( |
| 89 | + Conv2dModule(batch_norm=True, inplace=True), |
| 90 | + True, |
| 91 | + ), |
| 92 | + "conv2d_relu_inplace_per_channel": ( |
| 93 | + Conv2dModule(batch_norm=False, inplace=True), |
| 94 | + True, |
| 95 | + ), |
| 96 | + "conv2d_bn_relu_inplace_per_tensor": ( |
| 97 | + Conv2dModule(batch_norm=True, inplace=True), |
| 98 | + False, |
| 99 | + ), |
| 100 | + "conv2d_relu_inplace_per_tensor": ( |
| 101 | + Conv2dModule(batch_norm=False, inplace=True), |
| 102 | + False, |
| 103 | + ), |
48 | 104 | } |
49 | 105 |
|
50 | 106 |
|
51 | | -@common.parametrize("test_data", models) |
| 107 | +@common.parametrize( |
| 108 | + "test_data", |
| 109 | + models, |
| 110 | +) |
52 | 111 | def test_qat_tosa_INT(test_data): |
53 | 112 | model, per_channel = test_data |
54 | 113 | pipeline = TosaPipelineINT[input_t1](model, model.test_data, [], [], qtol=1) |
55 | | - tosa_version = conftest.get_option("tosa_version") |
56 | | - tosa_profiles = { |
57 | | - "1.0": common.TosaSpecification.create_from_string("TOSA-1.0+INT"), |
58 | | - } |
59 | | - tosa_spec = tosa_profiles[tosa_version] |
60 | | - quantizer = TOSAQuantizer(tosa_spec) |
| 114 | + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) |
61 | 115 | pipeline.change_args( |
62 | 116 | "quantize", |
63 | 117 | Quantize( |
64 | 118 | quantizer=quantizer, |
65 | 119 | quantization_config=get_symmetric_quantization_config( |
66 | 120 | is_qat=True, is_per_channel=per_channel |
67 | 121 | ), |
68 | | - is_qat=True, |
69 | 122 | ), |
70 | 123 | ) |
71 | 124 | pipeline.run() |
0 commit comments