Skip to content

Commit ca91e74

Browse files
authored
Arm Backend: Add support for select_scatter.default (#15972)
Add tests for the select_scatter.default op and decompose it into other operators. The decomposition only needs to occur in the INT pipeline since it is already supported in the FP pipeline, where it is decomposed during the export stage. Signed-off-by: Agrima Khare <agrima.khare@arm.com>
1 parent 4d7e3ee commit ca91e74

File tree

4 files changed

+319
-0
lines changed

4 files changed

+319
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from .decompose_round_pass import DecomposeRoundPass # noqa
6868
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
6969
from .decompose_select import DecomposeSelectPass # noqa
70+
from .decompose_select_scatter_pass import DecomposeSelectScatterPass # noqa
7071
from .decompose_sign_pass import DecomposeSignPass # noqa
7172
from .decompose_silu_pass import DecomposeSiluPass # noqa
7273
from .decompose_sinh_pass import DecomposeSinhPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
DecomposeRoundPass,
7171
DecomposeScaledDotProductAttentionPass,
7272
DecomposeSelectPass,
73+
DecomposeSelectScatterPass,
7374
DecomposeSignPass,
7475
DecomposeSiluPass,
7576
DecomposeSinhPass,
@@ -330,6 +331,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
330331
# Transformation passes (pre scalar -> tensor)
331332
self.add_passes(
332333
[
334+
DecomposeSelectScatterPass(),
333335
ConvertInt64ConstOpsToInt32Pass(),
334336
ConvertInt64OutputOpsToInt32Pass(),
335337
InsertInt32CastsAfterInt64PlaceholdersPass(),
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Set, Type
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes import ArmPass
11+
from executorch.backends.arm._passes.convert_int64_const_ops_to_int32 import (
12+
ConvertInt64ConstOpsToInt32Pass,
13+
)
14+
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
15+
ReplaceScalarWithTensorByProfilePass,
16+
)
17+
from executorch.exir.dialects._ops import ops as exir_ops
18+
from executorch.exir.pass_base import ExportPass
19+
20+
# Edge-dialect select_scatter overloads that this module decomposes.
edge_scatter_ops = (exir_ops.edge.aten.select_scatter.default,)
# ATen-dialect select_scatter overloads that this module decomposes.
aten_scatter_ops = (torch.ops.aten.select_scatter.default,)
22+
23+
24+
def get_select_scatter_decomposition(op) -> tuple:
    """Return the operators used to decompose a select_scatter op.

    The operators are taken from the same dialect (edge or ATen) as ``op``
    and are returned in the fixed order
    ``(arange, eq, where, expand_copy, unsqueeze_copy, view_copy)``.

    Raises:
        RuntimeError: If ``op`` is listed in neither ``edge_scatter_ops`` nor
            ``aten_scatter_ops``.

    """
    # Pick the operator namespace once; the overload names are the same in
    # both dialects.
    if op in edge_scatter_ops:
        aten = exir_ops.edge.aten
    elif op in aten_scatter_ops:
        aten = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get select_scatter decomposition for op {op}")

    return (
        aten.arange.start_step,
        aten.eq.Scalar,
        aten.where.self,
        aten.expand_copy.default,
        aten.unsqueeze_copy.default,
        aten.view_copy.default,
    )
45+
46+
47+
class DecomposeSelectScatterPass(ArmPass):
    """Decompose select_scatter into arange, view, eq, unsqueeze, expand and where.

    select_scatter is decomposed into other ops during export, however this is
    only supported for the fp profile; for the int profile it has to be
    decomposed here.

    The decomposition is as follows:
        - Build a boolean mask that broadcasts against x
            eq(view(arange(0, dim_size), mask_shape), index)
        - Broadcast source to x
            expand(unsqueeze(source, dim), shape)
        - Route the updated slice while keeping the untouched lanes
            where(mask, expanded_source, x)

    This reflects the decomposition for the fp profile implemented in torch._refs

    """

    _passes_required_after: Set[Type[ExportPass]] = {
        ReplaceScalarWithTensorByProfilePass,
        ConvertInt64ConstOpsToInt32Pass,
    }

    def call_operator(self, op, args, kwargs, meta):
        """Replace a select_scatter call with its where-based decomposition.

        Any other operator is forwarded unchanged to the parent pass. For
        select_scatter, ``args`` is read as ``(input, source, dim, index)``;
        negative ``dim`` and ``index`` are normalised against the input shape
        before the replacement nodes are emitted.

        """
        emit = super().call_operator

        if op not in (edge_scatter_ops + aten_scatter_ops):
            return emit(op, args, kwargs, meta, updated=False)

        (
            arange_op,
            eq_op,
            where_op,
            expand_op,
            unsqueeze_op,
            view_op,
        ) = get_select_scatter_decomposition(op)

        x, source = args[0], args[1]
        dim, index = int(args[2]), int(args[3])

        shape = x.data.size()
        rank = len(shape)
        if dim < 0:
            dim %= rank
        dim_size = shape[dim]
        if index < 0:
            index += dim_size

        # All ones except along `dim`, where -1 lets view infer dim_size, so
        # the mask broadcasts against the full input shape.
        mask_shape = [-1 if axis == dim else 1 for axis in range(rank)]

        # Mask branch: positions along `dim`, compared against `index`.
        positions = emit(arange_op, (0, dim_size, 1), {}, meta, updated=False)
        positions = emit(view_op, (positions, mask_shape), {}, meta, updated=False)
        mask = emit(eq_op, (positions, index), {}, meta, updated=False)

        # Source branch: re-insert the selected dimension, then broadcast.
        expanded = emit(unsqueeze_op, (source, dim), {}, meta, updated=False)
        expanded = emit(expand_op, (expanded, shape), {}, meta, updated=False)

        # Take the source where the mask is set and the input elsewhere. Only
        # this final node is emitted with updated=True.
        return emit(where_op, (mask, expanded, x), {}, meta, updated=True)
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU85PipelineINT,
13+
OpNotSupportedPipeline,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
16+
VgfPipeline,
17+
)
18+
19+
# Each entry is a zero-argument factory returning the positional arguments of
# SelectScatter.forward, in order: (x, y, dim, index).
# y has the shape of x with dimension `dim` removed.
test_data_suite = {
    "rank2_rand": lambda: (
        torch.randint(-30, 30, (5, 9), dtype=torch.float32),
        torch.randint(0, 9, (9,), dtype=torch.float32),
        0,
        2,
    ),
    # NOTE(review): despite the name, x is drawn with torch.rand, not zeros.
    "rank2_zeros": lambda: (
        torch.rand((3, 2), dtype=torch.float32),
        torch.randint(0, 4, (2,), dtype=torch.float32),
        0,
        0,
    ),
    "rank3_rand": lambda: (
        torch.rand((2, 4, 5), dtype=torch.float32),
        torch.randint(-5, 5, (2, 5), dtype=torch.float32),
        1,
        0,
    ),
    "rank3_ones": lambda: (
        torch.ones((2, 3, 3), dtype=torch.float32),
        torch.rand((2, 3), dtype=torch.float32),
        2,
        2,
    ),
    "rank4_rand": lambda: (
        torch.rand((1, 2, 4, 5), dtype=torch.float32),
        torch.rand((2, 4, 5), dtype=torch.float32),
        0,
        0,
    ),
    # The only case with a negative index (-1, i.e. the last slice along dim 2).
    "rank4_ones": lambda: (
        torch.ones((2, 3, 3, 2), dtype=torch.float32),
        torch.randint(-5, 5, (2, 3, 2), dtype=torch.float32),
        2,
        -1,
    ),
    "rank5_ones": lambda: (
        torch.ones((3, 4, 20, 9, 5), dtype=torch.float32),
        torch.randn((3, 4, 20, 9), dtype=torch.float32),
        4,
        1,
    ),
    "rank6_rand": lambda: (
        torch.rand((1, 2, 3, 4, 2, 1), dtype=torch.float32),
        torch.randn((2, 3, 4, 2, 1), dtype=torch.float32),
        0,
        0,
    ),
}
69+
70+
71+
class SelectScatter(torch.nn.Module):
    """Module under test: scatters ``y`` into ``x`` at ``index`` along ``dim``.

    The class attributes hold the operator names passed to the test pipelines
    in this file.
    """

    # Operator names passed to the FP pipelines.
    fp_aten_op = "torch.ops.aten.select_scatter.default"
    fp_exir_op = ["executorch_exir_dialects_edge__ops_aten_select_scatter_default"]

    # Operator names passed to the INT pipelines.
    int_aten_ops = [
        "torch.ops.aten.arange.start_step",
        "torch.ops.aten.view_copy.default",
        "torch.ops.aten.unsqueeze_copy.default",
        "torch.ops.aten.expand_copy.default",
        "torch.ops.aten.where.self",
        "torch.ops.aten.eq.Tensor",
    ]
    int_exir_ops = [
        "executorch_exir_dialects_edge__ops_aten_eq_Tensor",
        "executorch_exir_dialects_edge__ops_aten_where_self",
        "executorch_exir_dialects_edge__ops_aten_arange_start_step",
        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
        "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default",
        "executorch_exir_dialects_edge__ops_aten_expand_copy_default",
    ]

    # Operator-name -> count mapping handed to OpNotSupportedPipeline in the
    # U55 test.
    u55_not_supported = {
        "executorch_exir_dialects_edge__ops_aten_eq_Tensor": 1,
        "executorch_exir_dialects_edge__ops_aten_where_self": 1,
    }

    def forward(self, x: torch.Tensor, y: torch.Tensor, dim: int, index: int):
        """Return a copy of ``x`` with slice ``index`` along ``dim`` replaced by ``y``."""
        return torch.select_scatter(x, y, dim, index)
97+
98+
99+
# Positional inputs of SelectScatter.forward: (x, y, dim, index).
input_t = Tuple[torch.Tensor, torch.Tensor, int, int]
100+
101+
102+
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_tosa_FP(test_module: input_t):
    """Run SelectScatter through TosaPipelineFP with the FP operator names.

    ``test_module`` is a factory from ``test_data_suite``; it is called here
    to build the input tuple.
    """
    TosaPipelineFP[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.fp_aten_op,
        exir_op=SelectScatter.fp_exir_op,
    ).run()
111+
112+
113+
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_tosa_INT(test_module: input_t):
    """Run SelectScatter through TosaPipelineINT with the INT operator names.

    ``test_module`` is a factory from ``test_data_suite``; it is called here
    to build the input tuple.
    """
    TosaPipelineINT[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.int_aten_ops,
        exir_op=SelectScatter.int_exir_ops,
    ).run()
122+
123+
124+
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_u55_INT(test_module: input_t):
    """Run SelectScatter through OpNotSupportedPipeline for the U55 subset.

    ``test_module`` is a factory from ``test_data_suite``; it is called here
    to build the input tuple.
    """
    # select_scatter is not supported on U55
    OpNotSupportedPipeline[input_t](
        SelectScatter(),
        test_module(),
        SelectScatter.u55_not_supported,
        quantize=True,
        u55_subset=True,
        n_expected_delegates=1,
    ).run()
136+
137+
138+
@common.XfailIfNoCorstone320
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_u85_INT(test_module: input_t):
    """Run SelectScatter through EthosU85PipelineINT with the INT operator names.

    ``test_module`` is a factory from ``test_data_suite``; it is called here
    to build the input tuple.
    """
    EthosU85PipelineINT[input_t](
        SelectScatter(),
        test_module(),
        aten_ops=SelectScatter.int_aten_ops,
        exir_ops=SelectScatter.int_exir_ops,
    ).run()
148+
149+
150+
@common.SkipIfNoModelConverter
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_vgf_FP(test_module: input_t):
    """Run SelectScatter through VgfPipeline for TOSA-1.0+FP.

    ``test_module`` is a factory from ``test_data_suite``; it is called here
    to build the input tuple.
    """
    VgfPipeline[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.fp_aten_op,
        exir_op=SelectScatter.fp_exir_op,
        tosa_version="TOSA-1.0+FP",
    ).run()
161+
162+
163+
@common.SkipIfNoModelConverter
@common.parametrize("test_module", test_data_suite)
def test_select_scatter_vgf_INT(test_module: input_t):
    """Run SelectScatter through VgfPipeline for TOSA-1.0+INT.

    ``test_module`` is a factory from ``test_data_suite``; it is called here
    to build the input tuple.
    """
    VgfPipeline[input_t](
        SelectScatter(),
        test_module(),
        aten_op=SelectScatter.int_aten_ops,
        exir_op=SelectScatter.int_exir_ops,
        tosa_version="TOSA-1.0+INT",
    ).run()

0 commit comments

Comments
 (0)