Skip to content

Commit def4c64

Browse files
NXP backend: Add implementation of Tanh operator converter (#13510)
### Summary

Add delegation support for the `aten.tanh` operator.

### Test plan

Unit tests provided in `backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py`.

cc @digantdesai @JakeStevens @robert-kalmar
1 parent ef3dcce commit def4c64

File tree

15 files changed

+227
-72
lines changed

15 files changed

+227
-72
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
4343
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
4444
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
45+
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
4546
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
4647
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
4748
}

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
5656
SoftmaxConverter,
5757
)
58+
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.tanh_converter import (
59+
TanhConverter,
60+
)
5861
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.view_copy_converter import (
5962
ViewCopyConverter,
6063
)
@@ -80,4 +83,5 @@
8083
"AdaptiveAvgPool2dConverter",
8184
"HardTanhConverter",
8285
"SigmoidConverter",
86+
"TanhConverter",
8387
]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.nxp.backend.ir.converter.node_converter import NodeConverter
7+
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
8+
BuiltinOperator,
9+
)
10+
from torch.fx import Node
11+
from torch.nn import Parameter
12+
13+
14+
class TanhConverter(NodeConverter):
    """Convert the `aten.tanh` edge operator into the TFLite `TANH` operator."""

    @staticmethod
    def _is_supported_in_IR(
        node: Node,
        parameters_mapping: dict[str, Parameter],
    ) -> bool:
        # Tanh is a unary element-wise op with no attributes, so every
        # instance maps directly onto TFLite's TANH builtin.
        return True

    def convert(self, node: Node):
        """Append a TFLite `TANH` operator mirroring `node` to the model builder."""
        self.assert_convertible(node)

        tflite_op = self._create_tflite_op_with_io_tensors(node)
        tflite_op.opcode_index = self.builder.op_code_index_for_op_type(
            BuiltinOperator.TANH
        )
        self.builder.append_operators([tflite_op])

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
206206
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
207207
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
208208
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
209+
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
209210
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
210211
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
211212
}

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
SharedSpecPattern,
3636
SigmoidPattern,
3737
SoftMaxPattern,
38+
TanhInPlacePattern,
39+
TanhPattern,
3840
ViewPattern,
3941
)
4042
from executorch.backends.nxp.quantizer.utils import (
@@ -223,6 +225,8 @@ def __init__(self):
223225
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
224226
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
225227
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
228+
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
229+
NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
226230
NeutronAtenQuantizer(ViewPattern(), static_qconfig),
227231
]
228232
)

backends/nxp/quantizer/patterns.py

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,35 @@ def get_anchors(
106106
)
107107

108108

109+
def get_anchors_for_fixed_quant_specs(
    fused_partition: list[fx.GraphModule],
    scale: float,
    zero_point: int,
    quant_min: int = -128,
    quant_max: int = 127,
) -> PartitionAnchors:
    """Build partition anchors whose output uses a fixed int8 quantization spec.

    Intended for single-input activation partitions (softmax, sigmoid, tanh)
    whose output range is known a priori, so the output's (scale, zero_point)
    are pinned instead of observed.
    """
    assert len(fused_partition[0].input_nodes) == 1
    anchor_node = fused_partition[0].nodes[-1]

    fixed_qspec = FixedQParamsQuantizationSpec(
        dtype=torch.int8,
        scale=scale,
        zero_point=zero_point,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=torch.per_tensor_affine,
    )

    return PartitionAnchors(
        inputs=[(anchor_node, 0)],
        weights=[],
        biases=[],
        output=[(anchor_node, fixed_qspec)],
    )
136+
137+
109138
class AbsPattern(SharedSpecPattern):
110139
"""
111140
Quantizer for Abs operator.
@@ -479,31 +508,6 @@ def partition_types(self):
479508
return [torch.ops.aten.view.default]
480509

481510

482-
def get_anchors_for_softmax_like_operators(
483-
fused_partition: List[fx.GraphModule],
484-
) -> PartitionAnchors:
485-
node = fused_partition[0].nodes[-1]
486-
assert len(fused_partition[0].input_nodes) == 1
487-
488-
qspec = FixedQParamsQuantizationSpec(
489-
dtype=torch.int8,
490-
scale=1.0 / 256.0,
491-
zero_point=-128,
492-
quant_min=-128,
493-
quant_max=127,
494-
qscheme=torch.per_tensor_affine,
495-
)
496-
497-
return PartitionAnchors(
498-
inputs=[(node, 0)],
499-
weights=[],
500-
biases=[],
501-
output=[
502-
(node, qspec),
503-
],
504-
)
505-
506-
507511
class SoftMaxPattern(QuantizationPattern):
508512
"""
509513
Quantizer for Softmax operator.
@@ -515,9 +519,47 @@ def partition_types(self) -> List[OpOverload]:
515519
return [torch.ops.aten.softmax.int]
516520

517521
def get_anchors(
518-
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
522+
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
519523
) -> PartitionAnchors:
520-
return get_anchors_for_softmax_like_operators(fused_partition)
524+
return get_anchors_for_fixed_quant_specs(
525+
fused_partition, scale=1.0 / 256.0, zero_point=-128
526+
)
527+
528+
529+
class TanhPattern(QuantizationPattern):
    """
    Quantizer for the Tanh operator.

    The quantization of the Tanh output is fixed to scale 1/128, zero point 0,
    dtype int8.
    """

    # Return annotation added for consistency with the sibling patterns
    # (e.g. SoftMaxPattern declares `-> List[OpOverload]`).
    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.tanh.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        # Output quant params are pinned (not observed) because the output
        # range of tanh is static; see the fixed spec in the docstring.
        return get_anchors_for_fixed_quant_specs(
            fused_partition, scale=1.0 / 128.0, zero_point=0
        )
545+
546+
547+
class TanhInPlacePattern(QuantizationPattern):
    """
    Quantizer for the in-place version of the Tanh operator (`torch.tanh_`).

    The quantization of the Tanh output is fixed to scale 1/128, zero point 0,
    dtype int8 — the same fixed spec as the out-of-place `TanhPattern`.
    """

    # Return annotation added for consistency with the sibling patterns
    # (e.g. SoftMaxPattern declares `-> List[OpOverload]`).
    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.tanh_.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
    ) -> PartitionAnchors:
        # Same pinned quantization parameters as TanhPattern; only the
        # matched aten op differs (in-place variant).
        return get_anchors_for_fixed_quant_specs(
            fused_partition, scale=1.0 / 128.0, zero_point=0
        )
521563

522564

523565
class SigmoidPattern(QuantizationPattern):
@@ -533,4 +575,6 @@ def partition_types(self) -> List[OpOverload]:
533575
def get_anchors(
534576
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
535577
) -> PartitionAnchors:
536-
return get_anchors_for_softmax_like_operators(fused_partition)
578+
return get_anchors_for_fixed_quant_specs(
579+
fused_partition, scale=1.0 / 256.0, zero_point=-128
580+
)

backends/nxp/run_unittests.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ cd $EXECUTORCH_DIR
1212

1313
# '-c /dev/null' is used to ignore root level pytest.ini.
1414
pytest -c /dev/null backends/nxp/tests/
15+
16+
python -m unittest discover -s backends/nxp/tests/ -v

backends/nxp/tests/ir/__init__.py

Whitespace-only changes.

backends/nxp/tests/ir/converter/__init__.py

Whitespace-only changes.

backends/nxp/tests/ir/converter/node_converter/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)