
Commit ef3dcce

NXP backend: Add support for the aten.cat operator. (#13505)
### Summary

Add delegation of `aten.cat` to Neutron, and a `CustomDelegationOptions` class that allows delegation rules to be overridden. `CustomDelegationOptions` is introduced to allow forced delegation of the `aten.cat` operator, opportunistically, in cases where the constraint on the number of channels cannot be determined automatically.

### Test plan

Unit tests provided in `backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py`.

cc @digantdesai @JakeStevens @robert-kalmar
1 parent 2813d07 commit ef3dcce

31 files changed: +735 −55 lines
backends/nxp/backend/custom_delegation_options.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+
+
+@dataclass
+class CustomDelegationOptions:
+    """The class allows the user to specify details which affect which nodes will be delegated."""
+
+    # Neutron requires the channel dimension to be a multiple of `num_macs` for concatenation (the cat op).
+    # Due to the different dim ordering in torch (channels first) and Neutron IR (channels last), the index of the
+    # channel dim is ambiguous. The cat converter therefore defensively requires both possible channel dimension
+    # indices to be a multiple of `num_macs`. `force_delegate_cat` allows the user to turn off this defensive check
+    # when it is known from the model design that the constraint will be satisfied.
+    force_delegate_cat: bool = False
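
For illustration, a minimal sketch of constructing the new options dataclass (the module path matches the import added in `conversion_context.py` below; the assertion only demonstrates the default value):

```python
from executorch.backends.nxp.backend.custom_delegation_options import (
    CustomDelegationOptions,
)

# Default: the cat converter keeps its defensive channel-multiple check.
default_options = CustomDelegationOptions()
assert default_options.force_delegate_cat is False

# Opt in to delegating `aten.cat` even when the channel constraint cannot be
# verified automatically (only safe if the model design guarantees it holds).
forced_options = CustomDelegationOptions(force_delegate_cat=True)
```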

backends/nxp/backend/edge_program_converter.py

Lines changed: 18 additions & 3 deletions
@@ -1,4 +1,4 @@
-# Copyright 2024 NXP
+# Copyright 2024-2025 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -10,6 +10,9 @@
 from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
     AtenModelBuilderDirector,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
+)
 from torch.export import ExportedProgram
 from torch.export.graph_signature import InputKind
 from torch.fx import Node
@@ -28,6 +31,7 @@
     exir_ops.edge.aten.addmm.default: AddMMConverter,  # noqa F405
     exir_ops.edge.aten.add.Tensor: AddTensorConverter,  # noqa F405
     exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,  # noqa F405
+    exir_ops.edge.aten.cat.default: CatConverter,  # noqa F405
     exir_ops.edge.aten.clone.default: CloneConverter,  # noqa F405
     exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter,  # noqa F405
     exir_ops.edge.aten.convolution.default: ConvolutionConverter,  # noqa F405
@@ -49,24 +53,30 @@ class EdgeProgramToIRConverter:
     """

     _default_conversion_config = ConversionConfig()
+    _default_delegation_options = CustomDelegationOptions()

     def convert_program(
         self,
         edge_program: ExportedProgram,
         conversion_config=_default_conversion_config,
+        custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
     ) -> (bytes, dict):
         """
         Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.

         :param edge_program: Converter ExportedProgram.
         :param conversion_config: ConversionConfig instance.
+        :param custom_delegation_options: Custom user options which affect node delegation.
         :return: TFLite flatbuffers as bytes.
         """
         node_formats = NodeFormatInference(edge_program).identify_node_formats()
         parameters_mapping = self.map_inputs_to_parameters(edge_program)

         cc = self.build_conversion_context(
-            parameters_mapping, node_formats, conversion_config
+            parameters_mapping,
+            node_formats,
+            conversion_config,
+            custom_delegation_options,
         )

         # Program conversion
@@ -162,6 +172,7 @@ def build_conversion_context(
         parameters_mapping: dict,
         node_formats: dict[Node, NodeFormat],
         conversion_config: ConversionConfig = _default_conversion_config,
+        custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
     ) -> ConversionContext:
         tflite_builder = AtenModelBuilderDirector(
             3, "TFLite from EdgeProgram", conversion_config
@@ -171,7 +182,11 @@ def build_conversion_context(
         tflite_builder.build_empty_buffer()

         context = ConversionContext(
-            tflite_builder, conversion_config, parameters_mapping, node_formats
+            tflite_builder,
+            conversion_config,
+            parameters_mapping,
+            node_formats,
+            custom_delegation_options,
         )

         return context
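
With the widened signature above, the options can be threaded through `convert_program`. A minimal sketch, assuming `edge_program` is an `ExportedProgram` already lowered to the Edge dialect; the variable names are illustrative, and the two return values simply follow the `(bytes, dict)` annotation:

```python
from executorch.backends.nxp.backend.custom_delegation_options import (
    CustomDelegationOptions,
)
from executorch.backends.nxp.backend.edge_program_converter import (
    EdgeProgramToIRConverter,
)

# Convert an Edge-dialect ExportedProgram, forcing delegation of aten.cat.
tflite_flatbuffers, conversion_info = EdgeProgramToIRConverter().convert_program(
    edge_program,
    custom_delegation_options=CustomDelegationOptions(force_delegate_cat=True),
)
```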

backends/nxp/backend/ir/conversion_context.py

Lines changed: 7 additions & 1 deletion
@@ -1,8 +1,11 @@
-# Copyright 2024 NXP
+# Copyright 2024-2025 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from executorch.backends.nxp.backend.custom_delegation_options import (
+    CustomDelegationOptions,
+)
 from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
 from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
     AtenModelBuilderDirector,
@@ -17,13 +20,15 @@ class ConversionContext:
     conversion_config: ConversionConfig
     parameters_mapping: dict[str, Parameter]
     node_formats: dict[Node, NodeFormat]
+    custom_delegation_options: CustomDelegationOptions

     def __init__(
         self,
         tflite_builder: AtenModelBuilderDirector,
         conversion_config: ConversionConfig,
         parameters_mapping: dict,
         node_formats: dict[Node, NodeFormat],
+        custom_delegation_options: CustomDelegationOptions,
     ):
         """
         Context with data related to current conversion.
@@ -35,3 +40,4 @@ def __init__(
         self.conversion_config = conversion_config
         self.parameters_mapping = parameters_mapping
         self.node_formats = node_formats
+        self.custom_delegation_options = custom_delegation_options

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 36 additions & 7 deletions
@@ -8,6 +8,9 @@

 import torch

+from executorch.backends.nxp.backend.custom_delegation_options import (
+    CustomDelegationOptions,
+)
 from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext
 from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
     AtenModelBuilderDirector,
@@ -70,19 +73,25 @@ def convert(self, node: Node):
     @staticmethod
     @abstractmethod
     def _is_supported_in_IR(
-        node: Node, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         """Check if the `node` can be converted to the intermediate representation.
         Classes which implement conversion for individual operators must overwrite this method.

         :param node: torch.Node to check.
         :param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
+        :param custom_delegation_options: Custom options which affect delegation.
         """
         pass

     @staticmethod
     def _is_supported_on_target(
-        node: Node, target: Target, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        target: Target,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         """Check if the node is supported on the target platform.
         Child classes should overwrite this method to implement specific target checks. The default implementation
@@ -91,22 +100,30 @@ def _is_supported_on_target(
         :param node: The node (edge operator) to check.
         :param target: Value of the `Target` enum representing the target platform to check for.
         :param parameters_mapping: Dictionary mapping tensor names to their static data (if they have it).
+        :param custom_delegation_options: Custom options which affect delegation.
         """
         return target == Target.RT700

     @classmethod
     def is_supported(
-        cls, node: Node, target: Target, parameters_mapping: dict[str, Parameter]
+        cls,
+        node: Node,
+        target: Target,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         """Check if the given `node` is supported in the IR and on the given `target` platform.

         :param node: torch.Node to check.
         :param target: Value of the `Target` enum representing the target platform to check for.
         :param parameters_mapping: Dict mapping tensor names to their data.
+        :param custom_delegation_options: Custom user options which affect node delegation.
         """
         return cls._is_supported_in_IR(
-            node, parameters_mapping
-        ) and cls._is_supported_on_target(node, target, parameters_mapping)
+            node, parameters_mapping, custom_delegation_options
+        ) and cls._is_supported_on_target(
+            node, target, parameters_mapping, custom_delegation_options
+        )

     @staticmethod
     def _has_shared_q_params_if_quantized(node: Node) -> bool:
@@ -145,7 +162,11 @@ def assert_convertible(self, node):
         """Assert that the call `_is_supported_in_IR()` returns `True`. Otherwise, raise an exception and print an
         error message.
         """
-        assert self._is_supported_in_IR(node, self.context.parameters_mapping), (
+        assert self._is_supported_in_IR(
+            node,
+            self.context.parameters_mapping,
+            self.context.custom_delegation_options,
+        ), (
             f"Node `{node}` is not convertible to the intermediate representation. "
             "There is an error in the partitioner."
         )
@@ -169,7 +190,15 @@ def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator

         # Initialize node's inputs
         t_operator.inputs = tflite_model.OperatorInputs()
-        input_nodes = [arg for arg in node.args if isinstance(arg, Node)]
+
+        input_nodes = []
+        for arg in node.args:
+            match arg:
+                case Node():
+                    input_nodes.append(arg)
+                case list() if all(isinstance(node_, Node) for node_ in arg):
+                    input_nodes.extend(arg)
+
         for ancestor_node in input_nodes:
             assert self.context.tflite_builder.tensor_exists(ancestor_node.name)
             t_operator.tmp_inputs.append(
backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.avg_pool_2d_converter import (
     AvgPool2dConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.cat_converter import (
+    CatConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.clone_converter import (
     CloneConverter,
 )
@@ -58,6 +61,7 @@

 __all__ = [
     "AddMMConverter",
+    "CatConverter",
     "ConvolutionConverter",
     "MMConverter",
     "PermuteCopyConverter",
backends/nxp/backend/ir/converter/node_converters/ops_converters/abs_converter.py

Lines changed: 7 additions & 2 deletions
@@ -4,7 +4,10 @@
 # LICENSE file in the root directory of this source tree.


-from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
+    NodeConverter,
+)
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
     abs_options,
 )
@@ -16,7 +19,9 @@ class AbsConverter(NodeConverter):

     @staticmethod
     def _is_supported_in_IR(
-        node: Node, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         return True

backends/nxp/backend/ir/converter/node_converters/ops_converters/adaptive_avg_pool_2d_converter.py

Lines changed: 7 additions & 2 deletions
@@ -5,7 +5,10 @@

 import executorch.backends.nxp.backend.ir.lib.tflite.Padding as tflPadding
 from executorch.backends.nxp.backend.ir.converter.conversion import common
-from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
+    NodeConverter,
+)
 from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
     average_pool_2d_options,
@@ -19,7 +22,9 @@ class AdaptiveAvgPool2dConverter(NodeConverter):

     @staticmethod
     def _is_supported_in_IR(
-        node: Node, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         input_size = node.args[0].meta["val"].shape
         output_size = node.args[1]

backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py

Lines changed: 8 additions & 2 deletions
@@ -7,6 +7,7 @@
     node_uses_shape_broadcasting,
 )
 from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
     NodeConverter,
     Target,
 )
@@ -20,7 +21,10 @@
 class AddTensorConverter(NodeConverter):
     @staticmethod
     def _is_supported_on_target(
-        node: Node, target: Target, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        target: Target,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         match target:
             case Target.RT700:
@@ -35,7 +39,9 @@ def _is_supported_on_target(

     @staticmethod
     def _is_supported_in_IR(
-        node: Node, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         if len(node.args) != 2:
             return False

backends/nxp/backend/ir/converter/node_converters/ops_converters/addmm_converter.py

Lines changed: 7 additions & 2 deletions
@@ -5,7 +5,10 @@

 from executorch.backends.nxp.backend.edge_helper import input_rank
 from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
-from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
+    NodeConverter,
+)
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
     fully_connected_options,
 )
@@ -18,7 +21,9 @@ class AddMMConverter(NodeConverter):

     @staticmethod
     def _is_supported_in_IR(
-        node: Node, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         if len(node.all_input_nodes) != 3:
             return False

backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py

Lines changed: 7 additions & 2 deletions
@@ -8,7 +8,10 @@
     common,
 )
 from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
-from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    CustomDelegationOptions,
+    NodeConverter,
+)
 from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
     average_pool_2d_options,
@@ -21,7 +24,9 @@ class AvgPool2dConverter(NodeConverter):

     @staticmethod
     def _is_supported_in_IR(
-        node: Node, parameters_mapping: dict[str, Parameter]
+        node: Node,
+        parameters_mapping: dict[str, Parameter],
+        custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
         n_args = len(node.args)
