From 033580deabb7917d0cf71eb88ed776f4818f6bcc Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 18 Nov 2025 15:25:41 +0000 Subject: [PATCH] Arm backend: Add int16x8 LayerNorm test cases - Updates test_rsqrt to use a lower epsilon - Adds epsilon parameter to test_pipeline.py --- backends/arm/test/ops/test_layer_norm.py | 47 ++++++++++++++ backends/arm/test/ops/test_rsqrt.py | 79 ++++------------------- backends/arm/test/tester/test_pipeline.py | 9 ++- 3 files changed, 67 insertions(+), 68 deletions(-) diff --git a/backends/arm/test/ops/test_layer_norm.py b/backends/arm/test/ops/test_layer_norm.py index 2659bc2eab4..f3f11959d4d 100644 --- a/backends/arm/test/ops/test_layer_norm.py +++ b/backends/arm/test/ops/test_layer_norm.py @@ -137,3 +137,50 @@ def test_native_layer_norm_vgf_INT(test_data): tosa_version="TOSA-1.0+INT", ) pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_native_layer_norm_tosa_INT_a16w8(test_data): + """Test layer_norm with int16 I/O quantization for TOSA INT.""" + test_input, model = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + test_input, + "torch.ops.aten.sub.Tensor", # check for sub op in decomposition + symmetric_io_quantization=True, + tosa_extensions=["int16"], + epsilon=2**16, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_native_layer_norm_16a8w_u55_INT16(test_data): + """Test layer_norm with int16 I/O quantization for U55""" + test_input, model = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + test_input, + "torch.ops.aten.sub.Tensor", + symmetric_io_quantization=True, + a16w8_quantization=True, + epsilon=2**16, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_native_layer_norm_16a8w_u85_INT16(test_data): + """Test layer_norm with int16 I/O quantization for U85""" + test_input, model = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + test_input, + "torch.ops.aten.sub.Tensor", + symmetric_io_quantization=True, + a16w8_quantization=True, + epsilon=2**16, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_rsqrt.py b/backends/arm/test/ops/test_rsqrt.py index 9e2f024dcdd..0fea7ba2ec0 100644 --- a/backends/arm/test/ops/test_rsqrt.py +++ b/backends/arm/test/ops/test_rsqrt.py @@ -10,11 +10,8 @@ import pytest import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - get_symmetric_a16w8_quantization_config, - TOSAQuantizer, -) -from executorch.backends.arm.test import common, conftest + +from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -23,8 +20,6 @@ TosaPipelineINT, VgfPipeline, ) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.rsqrt.default" input_t1 = Tuple[torch.Tensor] # Input x @@ -112,48 +107,18 @@ def test_rsqrt_vgf_INT(test_tensor: torch.Tensor): pipeline.run() -def get_symmetric_a16w8_rsqrt_quantizer( - u55_config=False, per_channel_quantization=False -): - tosa_version = conftest.get_option("tosa_version") - tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), - } - - quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantizer.set_global( - get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) - ) - - return Quantize( - quantizer, - get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization - ), - ) - - @common.parametrize("test_tensor", Rsqrt.test_parameters) -@pytest.mark.xfail( - reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." -) -def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor): - """Test rsqrt operation with int16 quantization""" +def test_rsqrt_tosa_INT_a16w8(test_tensor: torch.Tensor): + """Test rsqrt operation with int16 I/O quantization for TOSA INT.""" + # Use wider tolerances for int16 I/O quantization pipeline = TosaPipelineINT[input_t1]( Rsqrt(), test_tensor(), aten_op, exir_op=[], - per_channel_quantization=False, - use_to_edge_transform_and_lower=True, tosa_extensions=["int16"], + epsilon=2**16, ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False), - ) - # Run the pipeline pipeline.run() @@ -163,46 +128,30 @@ def test_rsqrt_16a8w_tosa_INT(test_tensor: torch.Tensor): reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." ) def test_rsqrt_16a8w_u55_INT16(test_tensor: torch.Tensor): - """Test rsqrt operation with int16 quantization on U55""" + """Test rsqrt operation with int16 I/O quantization for U55""" + # Use wider tolerances for int16 I/O quantization on U55 pipeline = EthosU55PipelineINT[input_t1]( Rsqrt(), test_tensor(), aten_op, exir_ops=[], - per_channel_quantization=True, - use_to_edge_transform_and_lower=True, - atol=1e-03, - rtol=1e-03, - run_on_fvp=True, - ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=True), + a16w8_quantization=True, + epsilon=2**16, ) pipeline.run() @common.parametrize("test_tensor", Rsqrt.test_parameters) @common.XfailIfNoCorstone320 -@pytest.mark.xfail( - reason="MLETORCH-707: AssertionError: Output 0 does not match reference output." -) def test_rsqrt_16a8w_u85_INT16(test_tensor: torch.Tensor): - """Test rsqrt operation with int16 quantization on U85""" + """Test rsqrt operation with int16 I/O quantization for U85""" + # Use wider tolerances for int16 I/O quantization on U85 pipeline = EthosU85PipelineINT[input_t1]( Rsqrt(), test_tensor(), aten_op, exir_ops=[], - use_to_edge_transform_and_lower=True, - atol=1e-03, - rtol=1e-03, - run_on_fvp=True, - ) - - pipeline.change_args( - "quantize", - get_symmetric_a16w8_rsqrt_quantizer(per_channel_quantization=False), + a16w8_quantization=True, + epsilon=2**16, ) pipeline.run() diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 824f13417b2..bee035bd775 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -357,6 +357,7 @@ def __init__( qtol: int = 1, dynamic_shapes: Optional[Tuple[Any]] = None, tosa_extensions: Optional[List[str]] = None, + epsilon: float = 2**12, ): if tosa_extensions is None: tosa_extensions = [] @@ -377,7 +378,7 @@ def __init__( # choose 16A8W quantization config when int16 extension is requested if "int16" in tosa_extensions: quantization_config = get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization + is_per_channel=per_channel_quantization, epsilon=epsilon ) else: quantization_config = get_symmetric_quantization_config( @@ -550,6 +551,7 @@ def __init__( atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, + epsilon: float = 2**12, ): compile_spec = common.get_u55_compile_spec( custom_path=custom_path, @@ -559,7 +561,7 @@ def __init__( # choose int8 or int16 activation quantization if a16w8_quantization: quantization_config = get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization + is_per_channel=per_channel_quantization, epsilon=epsilon ) else: quantization_config = get_symmetric_quantization_config( @@ -650,6 +652,7 @@ def __init__( atol: float = 1e-03, rtol: float = 1e-03, qtol: int = 1, + epsilon: float = 2**12, ): compile_spec = common.get_u85_compile_spec( custom_path=custom_path, @@ -659,7 +662,7 @@ def __init__( # choose int8 or int16 activation quantization if a16w8_quantization: quantization_config = get_symmetric_a16w8_quantization_config( - is_per_channel=per_channel_quantization + is_per_channel=per_channel_quantization, epsilon=epsilon ) else: quantization_config = get_symmetric_quantization_config(