diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index f1276f13..d71997c7 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1720,6 +1720,56 @@ def inputs(self, _: Any) -> None: "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." ) + @property + def outputs(self) -> Sequence[Value]: + """The output values of the node. + + The outputs are immutable. To change the outputs, create a new node and + replace the inputs of the using nodes of this node's outputs by calling + :meth:`replace_input_with` on the using nodes of this node's outputs. + """ + return self._outputs + + @outputs.setter + def outputs(self, _: Sequence[Value]) -> None: + raise AttributeError("outputs is immutable. Please create a new node instead.") + + def i(self, index: int = 0) -> Value | None: + """Get the input value at the given index. + + This is a convenience method that is equivalent to ``node.inputs[index]``. + + The following is equivalent:: + + node.inputs[0] == node.i(0) == node.i() # Default index is 0 + node.inputs[index] == node.i(index) + + Returns: + The input value at the given index. + + Raises: + IndexError: If the index is out of range. + """ + return self.inputs[index] + + def o(self, index: int = 0) -> Value: + """Get the output value at the given index. + + This is a convenience method that is equivalent to ``node.outputs[index]``. + + The following is equivalent:: + + node.outputs[0] == node.o(0) == node.o() # Default index is 0 + node.outputs[index] == node.o(index) + + Returns: + The output value at the given index. + + Raises: + IndexError: If the index is out of range. + """ + return self.outputs[index] + def predecessors(self) -> Sequence[Node]: """Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" # Use the ordered nature of a dictionary to deduplicate the nodes @@ -1790,20 +1840,6 @@ def append(self, /, nodes: Node | Iterable[Node]) -> None: raise ValueError("The node to append to does not belong to any graph.") self._graph.insert_after(self, nodes) - @property - def outputs(self) -> Sequence[Value]: - """The output values of the node. - - The outputs are immutable. To change the outputs, create a new node and - replace the inputs of the using nodes of this node's outputs by calling - :meth:`replace_input_with` on the using nodes of this node's outputs. - """ - return self._outputs - - @outputs.setter - def outputs(self, _: Sequence[Value]) -> None: - raise AttributeError("outputs is immutable. Please create a new node instead.") - @property def attributes(self) -> _graph_containers.Attributes: """The attributes of the node as ``dict[str, Attr]`` with additional access methods. diff --git a/src/onnx_ir/external_data_test.py b/src/onnx_ir/external_data_test.py index f778b513..ffcd5a7a 100644 --- a/src/onnx_ir/external_data_test.py +++ b/src/onnx_ir/external_data_test.py @@ -173,13 +173,13 @@ def _simple_model(self) -> ir.Model: node_1 = ir.Node( "", "Op_1", - inputs=[node_0.outputs[0]], + inputs=[node_0.o()], num_outputs=1, name="node_1", ) graph = ir.Graph( inputs=node_0.inputs, # type: ignore - outputs=[node_1.outputs[0]], + outputs=[node_1.o()], initializers=[ ir.Value(name="tensor1", const_value=tensor1), ir.Value(name="tensor2", const_value=tensor2), diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index 5463cbad..1fd5a1a4 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -30,25 +30,41 @@ def test_pass_with_clear_metadata_and_docstring(self): ) mul_node = ir.node( "Mul", - inputs=[add_node.outputs[0], inputs[1]], + inputs=[add_node.o(), inputs[1]], num_outputs=1, metadata_props={"mul_key": "mul_value"}, doc_string="This is a Mul node", ) - func_inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ] function = ir.Function( graph=ir.Graph( name="my_function", - inputs=func_inputs, - outputs=mul_node.outputs, - nodes=[add_node, mul_node], + inputs=[ + input_a := ir.Value( + name="input_a", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 3)), + ), + input_b := ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 3)), + ), + ], + nodes=[ + add_node_func := ir.node( + "Add", + inputs=[input_a, input_b], + metadata_props={"add_key": "add_value"}, + doc_string="This is an Add node", + ), + mul_node_func := ir.node( + "Mul", + inputs=[add_node_func.o(), input_b], + metadata_props={"mul_key": "mul_value"}, + doc_string="This is a Mul node", + ), + ], + outputs=mul_node_func.outputs, opset_imports={"": 20}, doc_string="This is a function docstring", metadata_props={"function_key": "function_value"}, @@ -57,6 +73,14 @@ def test_pass_with_clear_metadata_and_docstring(self): domain="my_domain", attributes=[], ) + func_node = ir.node( + "my_function", + inputs=[inputs[0], mul_node.o()], + domain="my_domain", + metadata_props={"mul_key": "mul_value"}, + doc_string="This is a Mul node", + ) + # TODO(justinchuby): This graph is broken. The output of the function cannot be a input to a node # Create a model with the graph and function constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy())) const_node = ir.node( @@ -69,7 +93,7 @@ def test_pass_with_clear_metadata_and_docstring(self): ) sub_node = ir.node( "Sub", - inputs=[function.outputs[0], const_node.outputs[0]], + inputs=[func_node.o(), const_node.o()], num_outputs=1, metadata_props={"sub_key": "sub_value"}, doc_string="This is a Sub node", diff --git a/src/onnx_ir/passes/common/constant_manipulation.py b/src/onnx_ir/passes/common/constant_manipulation.py index da104118..94ba2dfd 100644 --- a/src/onnx_ir/passes/common/constant_manipulation.py +++ b/src/onnx_ir/passes/common/constant_manipulation.py @@ -41,7 +41,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: assert node.graph is not None if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): continue - if node.outputs[0].is_graph_output(): + if node.o().is_graph_output(): logger.debug( "Constant node '%s' is used as output, so it can't be lifted.", node.name ) @@ -54,7 +54,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: continue attr_name, attr_value = next(iter(node.attributes.items())) - initializer_name = node.outputs[0].name + initializer_name = node.o().name assert initializer_name is not None assert isinstance(attr_value, ir.Attr) tensor = self._constant_node_attribute_to_tensor( @@ -73,7 +73,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: assert node.graph is not None node.graph.register_initializer(initializer) # Replace the constant node with the initializer - ir.convenience.replace_all_uses_with(node.outputs[0], initializer) + ir.convenience.replace_all_uses_with(node.o(), initializer) node.graph.remove(node, safe=True) count += 1 logger.debug( diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index f3862917..8244829b 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -36,8 +36,8 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers( const_node = ir.node( "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1 ) - add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]]) - mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) + add_node = ir.node("Add", inputs=[inputs[0], const_node.o()]) + mul_node = ir.node("Mul", inputs=[add_node.o(), inputs[1]]) model = ir.Model( graph=ir.Graph( @@ -92,10 +92,10 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( ) # then branch adds the constant to the input # else branch multiplies the input by the constant - add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]]) + add_node = ir.node("Add", inputs=[input_value, then_const_node.o()]) then_graph = ir.Graph( inputs=[], - outputs=[add_node.outputs[0]], + outputs=[add_node.o()], nodes=[then_const_node, add_node], opset_imports={"": 20}, ) @@ -103,10 +103,10 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( else_const_node = ir.node( "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1 ) - mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]]) + mul_node = ir.node("Mul", inputs=[input_value, else_const_node.o()]) else_graph = ir.Graph( inputs=[], - outputs=[mul_node.outputs[0]], + outputs=[mul_node.o()], nodes=[else_const_node, mul_node], opset_imports={"": 20}, ) @@ -178,15 +178,13 @@ def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( attributes={constant_attribute: constant_value}, num_outputs=1, ) - identity_node_constant = ir.node( - "Identity", inputs=[const_node.outputs[0]], num_outputs=1 - ) + identity_node_constant = ir.node("Identity", inputs=[const_node.o()], num_outputs=1) identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) model = ir.Model( graph=ir.Graph( inputs=[input_value], - outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]], + outputs=[identity_node_input.o(), identity_node_constant.o()], nodes=[identity_node_input, const_node, identity_node_constant], opset_imports={"": 20}, ), @@ -232,7 +230,7 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self): model = ir.Model( graph=ir.Graph( inputs=[input_value], - outputs=[identity_node_input.outputs[0], const_node.outputs[0]], + outputs=[identity_node_input.o(), const_node.o()], nodes=[identity_node_input, const_node], opset_imports={"": 20}, ), @@ -272,7 +270,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( add_node = ir.node("Add", inputs=[input_value, then_initializer_value]) then_graph = ir.Graph( inputs=[], - outputs=[add_node.outputs[0]], + outputs=[add_node.o()], nodes=[add_node], opset_imports={"": 20}, initializers=[then_initializer_value], @@ -287,7 +285,7 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph( mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) else_graph = ir.Graph( inputs=[], - outputs=[mul_node.outputs[0]], + outputs=[mul_node.o()], nodes=[mul_node], opset_imports={"": 20}, initializers=[else_initializer_value], @@ -351,7 +349,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( # The initializer is also an input. We don't lift it to the main graph # to preserve the graph signature inputs=[then_initializer_value], - outputs=[add_node.outputs[0]], + outputs=[add_node.o()], nodes=[add_node], opset_imports={"": 20}, initializers=[then_initializer_value], @@ -366,7 +364,7 @@ def test_pass_does_not_lift_initialized_inputs_in_subgraph( mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) else_graph = ir.Graph( inputs=[], - outputs=[mul_node.outputs[0]], + outputs=[mul_node.o()], nodes=[mul_node], opset_imports={"": 20}, initializers=[else_initializer_value],