From 59978b6e21fc64158c7d6167293a6fb57ff4c3d2 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 27 Jul 2023 16:03:56 +0900 Subject: [PATCH 1/2] fix --- .../mil/frontend/torch/test/test_passes.py | 22 +++++++++++++++++++ .../mil/frontend/torch/torchir_passes.py | 19 +++++++--------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_passes.py b/coremltools/converters/mil/frontend/torch/test/test_passes.py index 4401ebbfb..03582d81b 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_passes.py +++ b/coremltools/converters/mil/frontend/torch/test/test_passes.py @@ -405,3 +405,25 @@ def forward(self, x): y_cm = ct_model.predict({'x': x})['y'] assert((y_cm == np.zeros(shape)).all()) + + @staticmethod + def test_inpace_op_from_add(): + class Net(torch.nn.Module): + def forward(self, x): + y = torch.empty(x.shape).to(torch.int32) + y.fill_(0) + return y + + shape = (2, 3) + x = torch.rand(*shape) + traced_fn = torch.jit.trace(Net(), x).eval() + + ct_model = ct.convert( + traced_fn, + inputs=[ct.TensorType(shape=shape)], + outputs=[ct.TensorType(name="y", dtype=np.int32)], + source="pytorch", + ) + y_cm = ct_model.predict({'x': x})['y'] + + assert((y_cm == np.zeros(shape)).all()) diff --git a/coremltools/converters/mil/frontend/torch/torchir_passes.py b/coremltools/converters/mil/frontend/torch/torchir_passes.py index d066d9a9d..100fe66a1 100644 --- a/coremltools/converters/mil/frontend/torch/torchir_passes.py +++ b/coremltools/converters/mil/frontend/torch/torchir_passes.py @@ -135,17 +135,7 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse): tensor_to_node_sequence_mapping.pop(node_input) node_sequence.append(node) tensor_to_node_sequence_mapping[node_output] = node_sequence - - if node.kind == "to": - node_input = node.inputs[0] - if node_input in tensor_to_node_sequence_mapping: - # update the mapping - node_output = node.outputs[0] - val = tensor_to_node_sequence_mapping[node_input] - del tensor_to_node_sequence_mapping[node_input] - tensor_to_node_sequence_mapping[node_output] = val - - if node.kind in ("copy_", "fill_"): + elif node.kind in ("copy_", "fill_"): node_input = node.inputs[0] if node_input not in tensor_to_node_sequence_mapping: raise ValueError("No matching select or slice.") @@ -176,6 +166,13 @@ def _construct_nodes_to_fuse_inputs(nodes_to_fuse): blocks=[], ) graph.nodes[i] = tensor_assign_node + elif node.inputs: + node_input = node.inputs[0] + if node_input in tensor_to_node_sequence_mapping: + # update the mapping + node_output = node.outputs[0] + val = tensor_to_node_sequence_mapping[node_input] + tensor_to_node_sequence_mapping[node_output] = val # modify the graph outputs if it is effected by this graph pass for idx in range(len(graph.outputs)): From 72844d809ea760100dab6a8b83d5f7c9df0e6ae6 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 27 Jul 2023 16:06:07 +0900 Subject: [PATCH 2/2] fix ttst --- coremltools/converters/mil/frontend/torch/test/test_passes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_passes.py b/coremltools/converters/mil/frontend/torch/test/test_passes.py index 03582d81b..f2c83fe9e 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_passes.py +++ b/coremltools/converters/mil/frontend/torch/test/test_passes.py @@ -410,7 +410,7 @@ def forward(self, x): def test_inpace_op_from_add(): class Net(torch.nn.Module): def forward(self, x): - y = torch.empty(x.shape).to(torch.int32) + y = torch.empty(x.shape) + 1 y.fill_(0) return y