Skip to content

Commit f255a99

Browse files
authored
Arm backend: Add int16x8 LayerNorm test cases (#16015)
### Summary
- Updates test_rsqrt to use a lower epsilon
- Adds epsilon parameter to test_pipeline.py
1 parent 94d96a1 commit f255a99

File tree

3 files changed

+67
-68
lines changed

3 files changed

+67
-68
lines changed

backends/arm/test/ops/test_layer_norm.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,50 @@ def test_native_layer_norm_vgf_INT(test_data):
137137
tosa_version="TOSA-1.0+INT",
138138
)
139139
pipeline.run()
140+
141+
142+
@common.parametrize("test_data", test_data_suite)
143+
def test_native_layer_norm_tosa_INT_a16w8(test_data):
144+
"""Test layer_norm with int16 I/O quantization for TOSA INT."""
145+
test_input, model = test_data()
146+
pipeline = TosaPipelineINT[input_t](
147+
model,
148+
test_input,
149+
"torch.ops.aten.sub.Tensor", # check for sub op in decomposition
150+
symmetric_io_quantization=True,
151+
tosa_extensions=["int16"],
152+
epsilon=2**16,
153+
)
154+
pipeline.run()
155+
156+
157+
@common.parametrize("test_data", test_data_suite)
158+
@common.XfailIfNoCorstone300
159+
def test_native_layer_norm_16a8w_u55_INT16(test_data):
160+
"""Test layer_norm with int16 I/O quantization for U55"""
161+
test_input, model = test_data()
162+
pipeline = EthosU55PipelineINT[input_t](
163+
model,
164+
test_input,
165+
"torch.ops.aten.sub.Tensor",
166+
symmetric_io_quantization=True,
167+
a16w8_quantization=True,
168+
epsilon=2**16,
169+
)
170+
pipeline.run()
171+
172+
173+
@common.parametrize("test_data", test_data_suite)
174+
@common.XfailIfNoCorstone320
175+
def test_native_layer_norm_16a8w_u85_INT16(test_data):
176+
"""Test layer_norm with int16 I/O quantization for U85"""
177+
test_input, model = test_data()
178+
pipeline = EthosU85PipelineINT[input_t](
179+
model,
180+
test_input,
181+
"torch.ops.aten.sub.Tensor",
182+
symmetric_io_quantization=True,
183+
a16w8_quantization=True,
184+
epsilon=2**16,
185+
)
186+
pipeline.run()

backends/arm/test/ops/test_rsqrt.py

Lines changed: 14 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,8 @@
1010

1111
import pytest
1212
import torch
13-
from executorch.backends.arm.quantizer.arm_quantizer import (
14-
get_symmetric_a16w8_quantization_config,
15-
TOSAQuantizer,
16-
)
17-
from executorch.backends.arm.test import common, conftest
13+
14+
from executorch.backends.arm.test import common
1815

1916
from executorch.backends.arm.test.tester.test_pipeline import (
2017
EthosU55PipelineINT,
@@ -23,8 +20,6 @@
2320
TosaPipelineINT,
2421
VgfPipeline,
2522
)
26-
from executorch.backends.arm.tosa import TosaSpecification
27-
from executorch.backends.xnnpack.test.tester import Quantize
2823

2924
aten_op = "torch.ops.aten.rsqrt.default"
3025
input_t1 = Tuple[torch.Tensor] # Input x
@@ -112,48 +107,18 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
112107
pipeline.run()
113108

114109

115-
def get_symmetric_a16w8_rsqrt_quantizer(
116-
u55_config=False, per_channel_quantization=False
117-
):
118-
tosa_version = conftest.get_option("tosa_version")
119-
tosa_profiles = {
120-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
121-
}
122-
123-
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
124-
quantizer.set_global(
125-
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
126-
)
127-
128-
return Quantize(
129-
quantizer,
130-
get_symmetric_a16w8_quantization_config(
131-
is_per_channel=per_channel_quantization
132-
),
133-
)
134-
135-
136110
@common.parametrize("test_tensor", Rsqrt.test_parameters)
137-
@pytest.mark.xfail(
138-
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
139-
)
140-
def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor):
141-
"""Test rsqrt operation with int16 quantization"""
111+
def test_rsqrt_tosa_INT_a16w8(test_tensor: torch.Tensor):
112+
"""Test rsqrt operation with int16 I/O quantization for TOSA INT."""
113+
# Use wider tolerances for int16 I/O quantization
142114
pipeline = TosaPipelineINT[input_t1](
143115
Rsqrt(),
144116
test_tensor(),
145117
aten_op,
146118
exir_op=[],
147-
per_channel_quantization=False,
148-
use_to_edge_transform_and_lower=True,
149119
tosa_extensions=["int16"],
120+
epsilon=2**16,
150121
)
151-
152-
pipeline.change_args(
153-
"quantize",
154-
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False),
155-
)
156-
# Run the pipeline
157122
pipeline.run()
158123

159124

@@ -163,46 +128,30 @@ def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor):
163128
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
164129
)
165130
def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor):
166-
"""Test rsqrt operation with int16 quantization on U55"""
131+
"""Test rsqrt operation with int16 I/O quantization for U55"""
132+
# Use wider tolerances for int16 I/O quantization on U55
167133
pipeline = EthosU55PipelineINT[input_t1](
168134
Rsqrt(),
169135
test_tensor(),
170136
aten_op,
171137
exir_ops=[],
172-
per_channel_quantization=True,
173-
use_to_edge_transform_and_lower=True,
174-
atol=1e-03,
175-
rtol=1e-03,
176-
run_on_fvp=True,
177-
)
178-
179-
pipeline.change_args(
180-
"quantize",
181-
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=True),
138+
a16w8_quantization=True,
139+
epsilon=2**16,
182140
)
183141
pipeline.run()
184142

185143

186144
@common.parametrize("test_tensor", Rsqrt.test_parameters)
187145
@common.XfailIfNoCorstone320
188-
@pytest.mark.xfail(
189-
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
190-
)
191146
def test_rsqrt_16a8w_u85_INT16(test_tensor: torch.Tensor):
192-
"""Test rsqrt operation with int16 quantization on U85"""
147+
"""Test rsqrt operation with int16 I/O quantization for U85"""
148+
# Use wider tolerances for int16 I/O quantization on U85
193149
pipeline = EthosU85PipelineINT[input_t1](
194150
Rsqrt(),
195151
test_tensor(),
196152
aten_op,
197153
exir_ops=[],
198-
use_to_edge_transform_and_lower=True,
199-
atol=1e-03,
200-
rtol=1e-03,
201-
run_on_fvp=True,
202-
)
203-
204-
pipeline.change_args(
205-
"quantize",
206-
get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False),
154+
a16w8_quantization=True,
155+
epsilon=2**16,
207156
)
208157
pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def __init__(
357357
qtol: int = 1,
358358
dynamic_shapes: Optional[Tuple[Any]] = None,
359359
tosa_extensions: Optional[List[str]] = None,
360+
epsilon: float = 2**12,
360361
):
361362
if tosa_extensions is None:
362363
tosa_extensions = []
@@ -377,7 +378,7 @@ def __init__(
377378
# choose 16A8W quantization config when int16 extension is requested
378379
if "int16" in tosa_extensions:
379380
quantization_config = get_symmetric_a16w8_quantization_config(
380-
is_per_channel=per_channel_quantization
381+
is_per_channel=per_channel_quantization, epsilon=epsilon
381382
)
382383
else:
383384
quantization_config = get_symmetric_quantization_config(
@@ -550,6 +551,7 @@ def __init__(
550551
atol: float = 1e-03,
551552
rtol: float = 1e-03,
552553
qtol: int = 1,
554+
epsilon: float = 2**12,
553555
):
554556
compile_spec = common.get_u55_compile_spec(
555557
custom_path=custom_path,
@@ -559,7 +561,7 @@ def __init__(
559561
# choose int8 or int16 activation quantization
560562
if a16w8_quantization:
561563
quantization_config = get_symmetric_a16w8_quantization_config(
562-
is_per_channel=per_channel_quantization
564+
is_per_channel=per_channel_quantization, epsilon=epsilon
563565
)
564566
else:
565567
quantization_config = get_symmetric_quantization_config(
@@ -650,6 +652,7 @@ def __init__(
650652
atol: float = 1e-03,
651653
rtol: float = 1e-03,
652654
qtol: int = 1,
655+
epsilon: float = 2**12,
653656
):
654657
compile_spec = common.get_u85_compile_spec(
655658
custom_path=custom_path,
@@ -659,7 +662,7 @@ def __init__(
659662
# choose int8 or int16 activation quantization
660663
if a16w8_quantization:
661664
quantization_config = get_symmetric_a16w8_quantization_config(
662-
is_per_channel=per_channel_quantization
665+
is_per_channel=per_channel_quantization, epsilon=epsilon
663666
)
664667
else:
665668
quantization_config = get_symmetric_quantization_config(

0 commit comments

Comments
 (0)