From fa596ab71b6d1495615fff080dda2112cf4dc01f Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Thu, 21 Aug 2025 11:23:48 +0200 Subject: [PATCH 1/2] NXP backend: Use zero_point to pad quantized convolution. --- .../ops_converters/convolution_converter.py | 35 +++++++++++++++---- .../node_converters/shared/conv_utils.py | 15 ++++++-- .../node_converter/test_conv_converter.py | 8 ++++- 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 653fc577c73..073c19f8871 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -5,6 +5,8 @@ import numpy as np import torch +from torch.fx import Node +from torch.nn import Parameter from executorch.backends.nxp.backend.edge_helper import ( input_tensor, @@ -16,6 +18,9 @@ common, ) from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + tf_lite_type_to_numpy, +) from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -38,8 +43,6 @@ conv_2d_options, depthwise_conv_2d_options, ) -from torch.fx import Node -from torch.nn import Parameter class ConvolutionConverter(NodeConverter): @@ -188,9 +191,19 @@ def _convert_2d_conv( aten_translator.convert_padding(conv_params.padding) ) if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s. + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) conversion_result.ops_list.add_pre( - self.builder.create_pad_operator_before(t_op, 0, explicit_padding) + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, constant_value=pad_value + ) ) # DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels] @@ -227,9 +240,19 @@ def _convert_2d_conv( aten_translator.convert_padding(conv_params.padding) ) if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s. + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) conversion_result.ops_list.add_pre( - self.builder.create_pad_operator_before(t_op, 0, explicit_padding) + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, constant_value=pad_value + ) ) return conversion_result.ops_list.flatten() diff --git a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py index ce03d4f6f15..3422e214982 100755 --- a/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py +++ b/backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py @@ -14,6 +14,9 @@ ) from executorch.backends.nxp.backend.ir.converter.conversion import aten_translator from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + tf_lite_type_to_numpy, +) from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data from executorch.backends.nxp.backend.ir.lib.tflite.Padding import Padding from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model @@ -289,9 +292,17 @@ def build_input_tensor_padding( tfl_padding, explicit_padding = aten_translator.convert_padding(conv_params.padding) if explicit_padding is not None: - # Must add extra 'Pad' operator + # Must add extra 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) return tfl_padding, builder.create_pad_operator_before( - t_op, input_idx, explicit_padding + t_op, input_idx, explicit_padding, pad_value ) return tfl_padding, None diff --git a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py index eb2818570f1..b116e909cb5 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py @@ -326,7 +326,7 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): ops = spy.spy_return.sub_graphs[0].operators.vector assert len(ops) == 2 - assert ops[0].builtin_options.operator_type == BuiltinOperator.PAD + assert ops[0].builtin_options.operator_type == BuiltinOperator.PADV2 assert ops[1].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D nodes = list(edge_program.graph.nodes) @@ -335,6 +335,12 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker): ) # input, Quant, lowered_module, delegate_call, getitem, Deq, output assert nodes[2].target == "lowered_module_0" + # Make sure the padding used the `zero-point`. + assert ( + ops[0].tmp_inputs[2].tmp_buffer.data.item() + == ops[0].tmp_outputs[0].quantization.zero_point[0] + ) + @pytest.mark.parametrize("stride", [1, 2]) @pytest.mark.parametrize("dilation", [1, 2]) From 54ef48c840084a9857e58c4f78a63c170b19889a Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Wed, 13 Aug 2025 14:35:14 +0200 Subject: [PATCH 2/2] NXP backend: Use zero_point to pad quantized average_pool. --- .../ops_converters/avg_pool_2d_converter.py | 20 ++++++- .../ops_converters/convolution_converter.py | 4 +- backends/nxp/tests/executorch_pipeline.py | 2 +- .../test_avg_pool2d_converter.py | 52 +++++++++++++++++++ 4 files changed, 73 insertions(+), 5 deletions(-) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py index 5654fdfab42..99ae0a30dbb 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py @@ -3,11 +3,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import numpy as np + from executorch.backends.nxp.backend.ir.converter.conversion import ( aten_translator, common, ) from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList +from executorch.backends.nxp.backend.ir.converter.conversion.translator import ( + tf_lite_type_to_numpy, +) from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, NodeConverter, @@ -62,9 +67,20 @@ def _convert_2d_avg_pool( ) if explicit_padding is not None: - # Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation! + # Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). But these will + # be included in the computation! + input_quantization = t_op.tmp_inputs[0].quantization + pad_value = ( + None + if input_quantization is None + else np.array(input_quantization.zero_point[0]).astype( + tf_lite_type_to_numpy(t_op.tmp_inputs[0].type) + ) + ) ops.add_pre( - self.builder.create_pad_operator_before(t_op, 0, explicit_padding) + self.builder.create_pad_operator_before( + t_op, 0, explicit_padding, pad_value + ) ) return ops.flatten() diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 073c19f8871..821aeb31432 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -5,8 +5,6 @@ import numpy as np import torch -from torch.fx import Node -from torch.nn import Parameter from executorch.backends.nxp.backend.edge_helper import ( input_tensor, @@ -43,6 +41,8 @@ conv_2d_options, depthwise_conv_2d_options, ) +from torch.fx import Node +from torch.nn import Parameter class ConvolutionConverter(NodeConverter): diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 7fc7cb7fb3c..3216bee7262 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -51,7 +51,7 @@ def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]): def to_quantized_edge_program( model: torch.nn.Module, - input_shapes: tuple[int] | list[tuple[int]], + input_shapes: tuple[int, ...] | list[tuple[int, ...]], operators_not_to_delegate: list[str] = None, target="imxrt700", neutron_converter_flavor="SDK_25_03", diff --git a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py index 8b6b63bb53f..bcdbd955c71 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py @@ -10,6 +10,12 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.converter.builder.model_builder import ( + ModelBuilder, +) +from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import ( + BuiltinOperator, +) from executorch.backends.nxp.tests.executorch_pipeline import ( to_edge_program, to_quantized_edge_program, @@ -156,3 +162,49 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ tflite_output_preprocess=ToNCHWPreprocess(), input_data=input_data, ) + + +def test_avg_pool_2d_quant_conversion__padded(mocker): + input_shape = (1, 8, 8, 8) + model = AvgPool2dModule(True, 1) + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + ops_spy = mocker.spy(ModelBuilder, "finish") + + # Run conversion + _ = to_quantized_edge_program(model, input_shape) + + # Capture the converter operators. + ops = ops_spy.spy_return.sub_graphs[0].operators.vector + + # Capture generated model + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + + # Capture converted program + exported_program: ExportedProgram = converter_spy.call_args.args[1] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8) + + convert_run_compare( + exported_program, + tflite_input_preprocess=ToNHWCPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToNCHWPreprocess(), + input_data=input_data, + ) + + assert len(ops) == 2 + assert ops[0].builtin_options.operator_type == BuiltinOperator.PADV2 + assert ops[1].builtin_options.operator_type == BuiltinOperator.AVERAGE_POOL_2D + + # Make sure the padding used the `zero-point`. + pad_value = ops[0].tmp_inputs[2].tmp_buffer.data.item() + assert ( + pad_value == ops[0].tmp_inputs[0].quantization.zero_point[0] + ) # `Pad` input zp. + assert ( + pad_value == ops[0].tmp_outputs[0].quantization.zero_point[0] + ) # `Pad` output zp. + assert ( + pad_value == ops[1].tmp_inputs[0].quantization.zero_point[0] + ) # `AvgPool` input zp.