|
17 | 17 |
|
18 | 18 | import logging |
19 | 19 | import tempfile |
20 | | -from collections import deque |
21 | 20 | from itertools import count |
22 | | -from typing import cast, Dict, final, List, Set |
| 21 | +from typing import cast, Dict, final, List |
23 | 22 |
|
24 | 23 | import tosa_serializer as ts |
25 | 24 | from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec |
|
43 | 42 |
|
44 | 43 |
|
45 | 44 | def _annotate_external_ids(ep_graph: Graph) -> Dict[str, int]: |
46 | | - """Assign deterministic output IDs to nodes reachable from graph outputs. |
| 45 | + """Assign deterministic output IDs to leaf outputs. |
| 46 | +
|
| 47 | + Flattens the output structure and assigns the external ID |
| 48 | + based on the leaf position in the exported output tuple/list. |
47 | 49 |
|
48 | 50 | Args: |
49 | 51 | ep_graph (Graph): FX graph produced by export preprocessing. |
50 | 52 |
|
51 | 53 | Returns: |
52 | | - dict[str, int]: Mapping from node name to external output index. |
53 | | -
|
| 54 | + dict[str, int]: Mapping from *leaf output node name* to external output index. |
54 | 55 | """ |
55 | 56 | node2external_id = {} |
56 | 57 |
|
57 | | - def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]): |
58 | | - """Walk producer graph from ``start_nodes`` and record external IDs.""" |
59 | | - q = deque(start_nodes) |
60 | | - while q: |
61 | | - n = q.popleft() |
62 | | - if n in seen: |
63 | | - continue |
64 | | - seen.add(n) |
65 | | - node2external_id[n.name] = idx |
66 | | - # Walk backwards so we touch every producer |
67 | | - q.extend(n.all_input_nodes) |
| 58 | + def _collect_leaves(arg, nodes): |
| 59 | + # Collect only FX Nodes that are actual outputs |
| 60 | + # (ignore ints/None/etc inside structured outputs). |
| 61 | + if isinstance(arg, Node): |
| 62 | + nodes.append(arg) |
| 63 | + elif isinstance(arg, (list, tuple)): |
| 64 | + for a in arg: |
| 65 | + _collect_leaves(a, nodes) |
68 | 66 |
|
69 | 67 | out = ep_graph.output_node() |
70 | | - # First argument of output node is tuple of outputs |
71 | | - output_list = cast(tuple, out.args[0]) |
72 | | - seen: Set[Node] = set() |
73 | | - for idx, val in enumerate(output_list): |
74 | | - bfs_mark([val], idx, seen) |
| 68 | + out_leaves: list[Node] = [] |
| 69 | + # First argument of output is the structured container (tuple/list) of outputs |
| 70 | + _collect_leaves(out.args[0], out_leaves) |
| 71 | + |
| 72 | + # Map each output leaf's name to its position |
| 73 | + node2external_id = {leaf.name: idx for idx, leaf in enumerate(out_leaves)} |
| 74 | + |
75 | 75 | return node2external_id |
76 | 76 |
|
77 | 77 |
|
|
0 commit comments