Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,9 +716,11 @@ def version_1(cls, ctx, node, **kwargs):
# T output = Split(int32 split_dim, T value, @int num_split)
# T outputs = Split(T input, @INT axis, @INTS split)
split_dims = node.inputs[0].get_tensor_value()
new_split_dims = split_dims + len(node.output_shapes[0]) if split_dims < 0 else split_dims
new_split_dims = 1 if new_split_dims == 3 else new_split_dims
ctx.remove_input(node, node.input[0], 0)
node.set_attr("num_outputs", node.get_attr_int("num_split"))
node.set_attr("axis", split_dims)
node.set_attr("axis", new_split_dims)

@classmethod
def version_2(cls, ctx, node, **kwargs):
Expand Down
37 changes: 36 additions & 1 deletion tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,27 @@ def _switch_transpose_and_node(self, node, trans, update_shape=True):
self._g.set_shape(node.output[0], new_shape)
self._g.set_shape(trans.output[0], shape)
return True

# this is for the case where node has multiple outputs. e.g. split node.
def _switch_transpose_and_node_with_multiple_outputs(self, node, trans, update_shape=True):
input_index = self._get_input_index_for_trans(node, trans)
for idx,_output in enumerate(node.output):
shape = self._g.get_shape(_output)
nxt_nodes = self._g.find_output_consumers(_output)
if idx == 0:
transpose = trans
self._g.replace_input(node, node.input[input_index], transpose.input[0], input_index)
self._g.replace_input(trans, trans.input[0], _output, 0)
else:
transpose = self._g.make_node("Transpose", [_output], attr={"perm": trans.get_attr_value("perm")})
for nxt_node in nxt_nodes:
self._g.replace_input(nxt_node, _output, transpose.output[0])

if update_shape and shape:
perm_inv = invert_perm(transpose.get_attr_value("perm"))
new_shape = [shape[i] for i in perm_inv]
self._g.set_shape(_output, new_shape)
self._g.set_shape(transpose.output[0], shape)
return True
# if return value is True, then it means Transpose is handled as designed
# otherwise, it means that we skip handling since it is not in our support set
def _handle_nhwc_tranpose(self, trans):
Expand Down Expand Up @@ -694,6 +714,21 @@ def _split_handler(self, trans, node):
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
self._g.replace_inputs(node, [node.input[0], new_axes_const.output[0]])
return True
# handling having branches
if len(node.output) > 1:
trans_rank = get_transpose_rank(trans)
axes = node.get_attr_value("axis", 0)
perm = trans.get_attr("perm").ints
axes = [axes + trans_rank if axes < 0 else axes]
if split:
new_axes_np = np.array(split, dtype=np.int64)
new_axes_const = self._g.make_const(utils.make_name(node.inputs[1].name), new_axes_np)
# [Transpose -> Split -> next_nodes] -> [Split -> Transpose -> next_nodes]
if not self._switch_transpose_and_node_with_multiple_outputs(node, trans, 1):
return False
new_axes = [perm[a] for a in axes]
node.set_attr("axes", new_axes)
return True
return False

def _unsqueeze_handler(self, trans, node):
Expand Down
Loading