1010
1111import pytest
1212import torch
13- from executorch .backends .arm .quantizer .arm_quantizer import (
14- get_symmetric_a16w8_quantization_config ,
15- TOSAQuantizer ,
16- )
17- from executorch .backends .arm .test import common , conftest
13+
14+ from executorch .backends .arm .test import common
1815
1916from executorch .backends .arm .test .tester .test_pipeline import (
2017 EthosU55PipelineINT ,
2320 TosaPipelineINT ,
2421 VgfPipeline ,
2522)
26- from executorch .backends .arm .tosa import TosaSpecification
27- from executorch .backends .xnnpack .test .tester import Quantize
2823
2924aten_op = "torch.ops.aten.rsqrt.default"
3025input_t1 = Tuple [torch .Tensor ] # Input x
@@ -112,48 +107,18 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor):
112107 pipeline .run ()
113108
114109
115- def get_symmetric_a16w8_rsqrt_quantizer (
116- u55_config = False , per_channel_quantization = False
117- ):
118- tosa_version = conftest .get_option ("tosa_version" )
119- tosa_profiles = {
120- "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
121- }
122-
123- quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
124- quantizer .set_global (
125- get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
126- )
127-
128- return Quantize (
129- quantizer ,
130- get_symmetric_a16w8_quantization_config (
131- is_per_channel = per_channel_quantization
132- ),
133- )
134-
135-
136110@common .parametrize ("test_tensor" , Rsqrt .test_parameters )
137- @pytest .mark .xfail (
138- reason = "MLETORCH-707: AssertionError: Output 0 does not match reference output."
139- )
140- def test_rsqrt_16a8w_tosa_INT (test_tensor : torch .Tensor ):
141- """Test rsqrt operation with int16 quantization"""
111+ def test_rsqrt_tosa_INT_a16w8 (test_tensor : torch .Tensor ):
112+ """Test rsqrt operation with int16 I/O quantization for TOSA INT."""
113+ # Use wider tolerances for int16 I/O quantization
142114 pipeline = TosaPipelineINT [input_t1 ](
143115 Rsqrt (),
144116 test_tensor (),
145117 aten_op ,
146118 exir_op = [],
147- per_channel_quantization = False ,
148- use_to_edge_transform_and_lower = True ,
149119 tosa_extensions = ["int16" ],
120+ epsilon = 2 ** 16 ,
150121 )
151-
152- pipeline .change_args (
153- "quantize" ,
154- get_symmetric_a16w8_rsqrt_quantizer (per_channel_quantization = False ),
155- )
156- # Run the pipeline
157122 pipeline .run ()
158123
159124
@@ -163,46 +128,30 @@ def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor):
163128 reason = "MLETORCH-707: AssertionError: Output 0 does not match reference output."
164129)
165130def test_rsqrt_16a8w_u55_INT16 (test_tensor : torch .Tensor ):
166- """Test rsqrt operation with int16 quantization on U55"""
131+ """Test rsqrt operation with int16 I/O quantization for U55"""
132+ # Use wider tolerances for int16 I/O quantization on U55
167133 pipeline = EthosU55PipelineINT [input_t1 ](
168134 Rsqrt (),
169135 test_tensor (),
170136 aten_op ,
171137 exir_ops = [],
172- per_channel_quantization = True ,
173- use_to_edge_transform_and_lower = True ,
174- atol = 1e-03 ,
175- rtol = 1e-03 ,
176- run_on_fvp = True ,
177- )
178-
179- pipeline .change_args (
180- "quantize" ,
181- get_symmetric_a16w8_rsqrt_quantizer (per_channel_quantization = True ),
138+ a16w8_quantization = True ,
139+ epsilon = 2 ** 16 ,
182140 )
183141 pipeline .run ()
184142
185143
186144@common .parametrize ("test_tensor" , Rsqrt .test_parameters )
187145@common .XfailIfNoCorstone320
188- @pytest .mark .xfail (
189- reason = "MLETORCH-707: AssertionError: Output 0 does not match reference output."
190- )
191146def test_rsqrt_16a8w_u85_INT16 (test_tensor : torch .Tensor ):
192- """Test rsqrt operation with int16 quantization on U85"""
147+ """Test rsqrt operation with int16 I/O quantization for U85"""
148+ # Use wider tolerances for int16 I/O quantization on U85
193149 pipeline = EthosU85PipelineINT [input_t1 ](
194150 Rsqrt (),
195151 test_tensor (),
196152 aten_op ,
197153 exir_ops = [],
198- use_to_edge_transform_and_lower = True ,
199- atol = 1e-03 ,
200- rtol = 1e-03 ,
201- run_on_fvp = True ,
202- )
203-
204- pipeline .change_args (
205- "quantize" ,
206- get_symmetric_a16w8_rsqrt_quantizer (per_channel_quantization = False ),
154+ a16w8_quantization = True ,
155+ epsilon = 2 ** 16 ,
207156 )
208157 pipeline .run ()
0 commit comments