diff --git a/docs/api/ir_convenience.md b/docs/api/ir_convenience.md index 12dc1e0f..28ab27f9 100644 --- a/docs/api/ir_convenience.md +++ b/docs/api/ir_convenience.md @@ -13,4 +13,5 @@ .. autofunction:: get_const_tensor .. autofunction:: replace_all_uses_with .. autofunction:: replace_nodes_and_values +.. autofunction:: insert_nodes_in_value ``` diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py index 7842859d..da8836d6 100644 --- a/src/onnx_ir/_convenience/__init__.py +++ b/src/onnx_ir/_convenience/__init__.py @@ -15,6 +15,7 @@ "create_value_mapping", "replace_nodes_and_values", "get_const_tensor", + "insert_nodes_in_value", ] import logging @@ -376,6 +377,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: return values +def _update_graph_or_function_outputs( + graph_or_function: _core.Graph | _core.Function, + old_values: Sequence[_core.Value], + new_values: Sequence[_core.Value], +): + """Update graph/function outputs.""" + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(graph_or_function.outputs): + if graph_or_function_output in replacement_mapping: + graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + + def replace_nodes_and_values( graph_or_function: _core.Graph | _core.Function, /, @@ -407,10 +420,7 @@ def replace_nodes_and_values( # Reconnect the users of the deleted values to use the new values replace_all_uses_with(old_values, new_values) # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + _update_graph_or_function_outputs(graph_or_function, old_values, new_values) # insert new nodes after the index node graph_or_function.insert_after(insertion_point, new_nodes) @@ -518,3 +528,117 @@ def get_const_tensor( ) value.type = new_value_type return tensor + + +def _find_inputs_outputs( + nodes: Sequence[_core.Node], +) -> tuple[tuple[_core.Value | None, ...], tuple[_core.Value, ...]]: + """Find the values that are considered as inputs and outputs in a sequence of nodes.""" + # Search the unique inputs/outputs in new_nodes, keeping the order. + all_inputs = dict.fromkeys(sum((node.inputs for node in nodes), ())) # type: ignore[type-var] + all_outputs = dict.fromkeys(sum((node.outputs for node in nodes), ())) # type: ignore[type-var] + # A value is considered as input if it is not any output. + inputs = tuple(val for val in all_inputs if val not in all_outputs) + # A value is considered as output if it is not any input. + outputs = tuple(val for val in all_outputs if val not in all_inputs) + return inputs, outputs + + +def insert_nodes_in_value( + values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node] +) -> None: + """Inserts a sequence of nodes into the provided value(s). + + This allows to insert a list of LINKED nodes (over the same context) at + a specific point in the graph. + + For example, suppose we have the following graph:: + + input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output + + We want to insert [node_M, node_N] at B value:: + + >>> import onnx_ir as ir + >>> input = ir.Input("input") + >>> node_A = ir.node("op_A", [input]) + >>> B = ir.Value(name="B") + >>> node_B = ir.node("op_B", node_A.outputs, outputs=[B]) + >>> node_C = ir.node("op_C", node_B.outputs) + >>> # Create a new sequence to insert + >>> input_2 = ir.Input("input_2") + >>> node_M = ir.node("op_M", [input_2]) + >>> node_N = ir.node("op_N", node_M.outputs) + >>> # Insert nodes in B + >>> insert_nodes_in_value(node_B.outputs, [node_M, node_N]) + >>> len(node_B.outputs) + 1 + >>> node_B.outputs[0].consumers()[0].op_type + 'op_M' + >>> len(node_C.inputs) + 1 + >>> node_C.inputs[0].producer().op_type + 'op_N' + >>> node_C.inputs[0].name + 'B' + + When values is a sequence, the set of nodes must have the same number + of inputs and outputs, then they are zipped into pairs: first value is + replaced with the first input/output, and so on. + + Args: + values: The value(s) where to insert the nodes. + new_nodes: The nodes to insert in the graph. + """ + if not isinstance(values, Sequence): + values = (values,) + + # Search the unique inputs/outputs in new_nodes, keeping the order. + inputs, outputs = _find_inputs_outputs(new_nodes) + + # Sanity check. + if len(values) != len(inputs): + raise ValueError( + f"The number of values and inputs ({inputs}) in new_nodes must match." + ) + if len(values) != len(outputs): + raise ValueError( + f"The number of values and outputs ({outputs}) in new_nodes must match." + ) + + # Propagate relevant info. + for val, in_val, out_val in zip(values, inputs, outputs): + # Propagate relevant info from value to out_value. + # TODO(Rama): Perhaps this should be a separate utility function. + out_val.type = val.type + out_val.shape = val.shape + out_val.name = val.name + # Propagate relevant info from value to in_value. + # TODO(Rama): Perhaps this should be a separate utility function. + in_val.type = val.type + in_val.shape = val.shape + # Rename each value, following each input. + val.name = in_val.name + + # Insert the new nodes in two steps: + # 1. Reconnect the users of values to the outputs + replace_all_uses_with(values, outputs) + # 2. Reconnect the users of inputs to values + replace_all_uses_with(inputs, values) + + # Update graph if there is one: + if (graph := values[-1].graph) is not None: + # Update graph/function outputs if the node generates output + _update_graph_or_function_outputs(graph, values, outputs) + + # Insert new nodes if there is a graph + # Note nodes are inserted at the beginning when values are the graph inputs. + target_node = values[0].producer() + for v in values[1:]: + if (new_target_node := v.producer()) is None: + continue + if target_node is None or graph.index(target_node) < graph.index(new_target_node): + target_node = new_target_node + if target_node is None: + graph.insert_before(graph[0], new_nodes) + else: + graph.insert_after(target_node, new_nodes) diff --git a/src/onnx_ir/_convenience/_init_test.py b/src/onnx_ir/_convenience/_init_test.py new file mode 100644 index 00000000..b86162ab --- /dev/null +++ b/src/onnx_ir/_convenience/_init_test.py @@ -0,0 +1,146 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the _convenience module.""" + +import unittest + +import onnx + +import onnx_ir as ir +from onnx_ir._convenience import insert_nodes_in_value + + +def _create_model(model_text: str) -> ir.Model: + model = onnx.parser.parse_model(model_text) + return ir.serde.deserialize_model(model) + + +class ConvenienceTest(unittest.TestCase): + def test_insert_nodes_in_value(self): + # Main graph + input = ir.Input("input") + node_A = ir.node("op_A", [input]) + node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")]) + node_C = ir.node("op_C", node_B.outputs) + + # New sequence to insert + input_2 = ir.Input("input_2") + node_M = ir.node("op_M", [input_2]) + node_N = ir.node("op_N", node_M.outputs) + + # Insert nodes in B + insert_nodes_in_value(node_B.outputs[0], [node_M, node_N]) + self.assertEqual(len(node_B.outputs), 1) + self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M") + self.assertEqual(len(node_C.inputs), 1) + self.assertEqual(node_C.inputs[0].producer().op_type, "op_N") + self.assertEqual(node_C.inputs[0].name, "B") + + def test_insert_nodes_in_value_in_graph(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + + # Sequence to insert. + # Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]]. + i1, i2 = ir.Input("i1"), ir.Input("i2") + a = ir.node("op_1", [i1, i2]) + b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2) + c = ir.node("op_3", [i2, b.outputs[0]]) + + # Insert nodes in SplitNode.outputs + target_node = ir_model.graph[1] + insert_nodes_in_value(target_node.outputs, [a, b, c]) + + # Check target_node outputs have been renamed + new_i1, new_i2 = target_node.outputs + self.assertEqual(new_i1.name, "i1") + self.assertEqual(new_i2.name, "i2") + + # Check i1 and i2 have new users + self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2")) + self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3")) + + # Check outputs have been correctly renamed as previous values + self.assertEqual(b.outputs[1].name, "a") + self.assertEqual(c.outputs[0].name, "b") + + # Check nodes have been inserted in the graph + self.assertEqual(len(ir_model.graph), 6) + + def test_insert_nodes_in_input(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + z = Add(x, two) + } + """ + ) + + # Sequence to insert. + x = ir.Input("new_x") + node = ir.node("Mul", [x, x]) + + # Insert nodes in graph.inputs + insert_nodes_in_value(ir_model.graph[1].inputs[0], [node]) + self.assertEqual(node.outputs[0].name, "x") + + # Check input has been renamed + self.assertEqual(ir_model.graph.inputs[0].name, "new_x") + + # Finally, check new graph is valid + proto = ir.to_proto(ir_model) + onnx.checker.check_model(proto, full_check=True) + + def test_insert_nodes_in_output(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + z = Add(x, two) + } + """ + ) + + # Sequence to insert. + x = ir.Input("new_z") + node = ir.node("Mul", [x, x]) + + # Insert nodes in graph.inputs + insert_nodes_in_value(ir_model.graph.outputs[0], [node]) + self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z") + + # Check output name is preserved + self.assertEqual(ir_model.graph.outputs[0].name, "z") + + def test_value_error_for_wrong_number_of_points(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")]) + with self.assertRaisesRegex(ValueError, "The number of values and inputs"): + insert_nodes_in_value(ir_model.graph[0].outputs, [node]) + + with self.assertRaisesRegex(ValueError, "The number of values and outputs"): + insert_nodes_in_value(ir_model.graph[1].outputs, [node]) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/onnx_ir/convenience.py b/src/onnx_ir/convenience.py index f80d6b5c..c19565f6 100644 --- a/src/onnx_ir/convenience.py +++ b/src/onnx_ir/convenience.py @@ -9,6 +9,7 @@ "convert_attributes", "create_value_mapping", "get_const_tensor", + "insert_nodes_in_value", "replace_all_uses_with", "replace_nodes_and_values", ] @@ -18,6 +19,7 @@ convert_attributes, create_value_mapping, get_const_tensor, + insert_nodes_in_value, replace_all_uses_with, replace_nodes_and_values, )