From 3189e48ebfc43b825507f22e73795a3ad359815f Mon Sep 17 00:00:00 2001 From: Adrian Bauer Date: Thu, 12 Jun 2025 08:22:23 +0200 Subject: [PATCH 1/3] fixes #2538 --- .../mil/frontend/torch/test/test_passes.py | 50 +++++++++++++++++++ .../mil/frontend/torch/torchir_passes.py | 21 ++++---- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_passes.py b/coremltools/converters/mil/frontend/torch/test/test_passes.py index 4401ebbfb..6d6353c2d 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_passes.py +++ b/coremltools/converters/mil/frontend/torch/test/test_passes.py @@ -18,6 +18,7 @@ flatten_graph_input_values, flatten_graph_output_values, transform_inplace_ops, + remove_getattr_nodes ) import coremltools as ct @@ -405,3 +406,52 @@ def forward(self, x): y_cm = ct_model.predict({'x': x})['y'] assert((y_cm == np.zeros(shape)).all()) + + + @staticmethod + def test_remove_getattr_nodes_immediate_output(): + graph_nodes = [ + InternalTorchIRNode( + inputs=["self"], + attr={"name": "const_out2", "value": None}, + outputs=["const_out2"], + kind="getattr", + ), + InternalTorchIRNode( + inputs=["self"], + attr={"name": "const_out1", "value": None}, + outputs=["const_out1"], + kind="getattr", + ), + InternalTorchIRNode( + inputs=["const_out1", "const_out2"], + attr={"value": None}, + outputs=["3"], + kind="tupleconstruct", + ), + ] + const2 = torch.tensor([5., 6., 7., 8.]) + const1 = torch.tensor([1., 2., 3., 4.]) + graph_params = {'const_out2': const2, + 'const_out1': const1} + graph_inputs = [] + graph_outputs = ["const_out1", "const_out2"] + + graph = InternalTorchIRGraph( + nodes=graph_nodes, + params=graph_params, + inputs=graph_inputs, + outputs=graph_outputs, + ) + + for node in graph.nodes: + node.parent = graph + + remove_getattr_nodes(graph) + + np.testing.assert_equal(graph.nodes[0].kind, "constant") + np.testing.assert_equal(graph.nodes[1].kind, "constant") + np.testing.assert_equal(graph.nodes[2].kind, "tupleconstruct") + np.testing.assert_allclose(graph.nodes[0].attr["value"], const2) + np.testing.assert_allclose(graph.nodes[1].attr["value"], const1) + diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py index cf784cdf0..5fdbf95c1 100644 --- a/coremltools/converters/mil/frontend/torch/torchir_passes.py +++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py @@ -231,10 +231,9 @@ def forward(self, x): def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None: """ - Remove the getattr nodes in the graph + Remove the getattr nodes in the graph that are not output nodes """ - getattr_nodes = [] new_nodes = [] for node in graph.nodes: @@ -243,16 +242,20 @@ def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None: remove_getattr_nodes(block) if node.kind == "getattr": - getattr_nodes.append(node) + if node.name in graph.outputs: + # create and add new constant node + new_nodes.append( + InternalTorchIRNode( + inputs=[], + outputs=node.outputs, + kind="constant", + name="internal_immediate_output_attr", + attr={"value": node.parent.params[node.name]} + ) + ) else: new_nodes.append(node) - # check the getattr nodes not in the outputs - for node in getattr_nodes: - if node.name in graph.outputs: - raise RuntimeError("{} should not be in the graph outputs.".format(node.name)) - - # remove the getattr nodes graph.nodes = new_nodes From 7db53389169e1abab6fd9826cbc33fc572f7f7e3 Mon Sep 17 00:00:00 2001 From: Adrian Bauer <61182488+tritolol@users.noreply.github.com> Date: Thu, 12 Jun 2025 08:58:39 +0200 Subject: [PATCH 2/3] Update torchir_passes.py remove_getattr_nodes() description --- coremltools/converters/mil/frontend/torch/torchir_passes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py index 5fdbf95c1..e910b08d1 100644 --- a/coremltools/converters/mil/frontend/torch/torchir_passes.py +++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py @@ -231,7 +231,8 @@ def forward(self, x): def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None: """ - Remove the getattr nodes in the graph that are not output nodes + Remove the getattr nodes in the graph + If they are output nodes, convert them to constant nodes """ new_nodes = [] From 5d51d4f2f1df842b5943e3b7542e2dffb74d728f Mon Sep 17 00:00:00 2001 From: Adrian Bauer Date: Thu, 12 Jun 2025 13:51:35 +0200 Subject: [PATCH 3/3] added .detach() to connst value, added new unit test --- .../torch/test/test_torch_conversion_api.py | 26 +++++++++++++++++++ .../mil/frontend/torch/torchir_passes.py | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py b/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py index 6bf2cf6f0..9b11dc367 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py @@ -929,6 +929,32 @@ def forward( past_kv_len += 1 + @staticmethod + def test_immediate_return_getattr_model(): + class ImmediateReturnGetAttrModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("my_constant_output", torch.tensor([1.0, 2.0, 3.0, 4.0])) + self.register_buffer("my_constant_output2", torch.tensor([5.0, 6.0, 7.0, 8.0])) + + def forward(self, x): + # x is a dummy input, not used + return self.my_constant_output, self.my_constant_output2 + + model = ImmediateReturnGetAttrModel() + model.eval() + dummy_input = torch.zeros(1) # Dummy input for tracing + traced_model = torch.jit.trace(model, example_inputs=(dummy_input,)) + mlmodel = ct.convert( + traced_model, + inputs=[ct.TensorType(shape=(1,))], + convert_to='mlprogram' + ) + outputs = mlmodel.predict({"x": np.zeros(1)}) + assert "my_constant_output" in outputs + assert "my_constant_output2" in outputs + + ############################################################################### # Note: Stress tests for PyTorch input / output types ############################################################################### diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py index e910b08d1..d6e6fa203 100644 --- a/coremltools/converters/mil/frontend/torch/torchir_passes.py +++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py @@ -251,7 +251,7 @@ def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None: outputs=node.outputs, kind="constant", name="internal_immediate_output_attr", - attr={"value": node.parent.params[node.name]} + attr={"value": node.parent.params[node.name].detach()} ) ) else: