Skip to content

Commit c1113d3

Browse files
mansnilszingo
andauthored
Arm backend: Update output reorder workaround (#15981)
Updates _annotate_external_ids() to only assign ids to actual outputs in current order instead of traversing all producers. This makes the mapping more robust to structured (tuple/namedtuple) outputs and avoids interference from lifted constants or shared producers. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Co-authored-by: Zingo Andersen <zingo.andersen@arm.com>
1 parent f255a99 commit c1113d3

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

backends/arm/tosa/backend.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
import logging
1919
import tempfile
20-
from collections import deque
2120
from itertools import count
22-
from typing import cast, Dict, final, List, Set
21+
from typing import cast, Dict, final, List
2322

2423
import tosa_serializer as ts
2524
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
@@ -43,35 +42,36 @@
4342

4443

4544
def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]:
46-
"""Assign deterministic output IDs to nodes reachable from graph outputs.
45+
"""Assign deterministic output IDs to leaf outputs.
46+
47+
Flattens the output structure and assigns the external ID
48+
based on the leaf position in the exported output tuple/list.
4749
4850
Args:
4951
ep_graph (Graph): FX graph produced by export preprocessing.
5052
5153
Returns:
52-
dict[str, int]: Mapping from node name to external output index.
53-
54+
dict[str, int]: Mapping from *leaf output node name* to external output index.
5455
"""
5556
node2external_id = {}
5657

57-
def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
58-
"""Walk producer graph from ``start_nodes`` and record external IDs."""
59-
q = deque(start_nodes)
60-
while q:
61-
n = q.popleft()
62-
if n in seen:
63-
continue
64-
seen.add(n)
65-
node2external_id[n.name] = idx
66-
# Walk backwards so we touch every producer
67-
q.extend(n.all_input_nodes)
58+
def _collect_leaves(arg, nodes):
59+
# Collect only FX Nodes that are actual outputs
60+
# (ignore ints/None/etc inside structured outputs).
61+
if isinstance(arg, Node):
62+
nodes.append(arg)
63+
elif isinstance(arg, (list, tuple)):
64+
for a in arg:
65+
_collect_leaves(a, nodes)
6866

6967
out = ep_graph.output_node()
70-
# First argument of output node is tuple of outputs
71-
output_list = cast(tuple, out.args[0])
72-
seen: Set[Node] = set()
73-
for idx, val in enumerate(output_list):
74-
bfs_mark([val], idx, seen)
68+
out_leaves: list[Node] = []
69+
# First argument of output is the structured container (tuple/list) of outputs
70+
_collect_leaves(out.args[0], out_leaves)
71+
72+
# Map each output leaf's name to its position
73+
node2external_id = {leaf.name: idx for idx, leaf in enumerate(out_leaves)}
74+
7575
return node2external_id
7676

7777

0 commit comments

Comments
 (0)