Skip to content

Commit 3a2969c

Browse files
committed
unifying the squeeze layer
1 parent 0e4ee77 commit 3a2969c

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,7 @@ def expand_boolean_indices(
8787
# nonzero returns shape [N, dims], we need to extract dim i
8888
if len(indices) == 1:
8989
# x[mask] — 1D mask
90-
squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
91-
squeeze_layer.reshape_dims = (-1,)
92-
set_layer_name(
93-
squeeze_layer,
94-
target,
95-
name + f"_bool_nonzero_squeeze_{i}",
96-
source_ir,
97-
)
98-
squeezed_index = squeeze_layer.get_output(0)
99-
new_indices.append(squeezed_index)
90+
to_squeeze = nonzero_indices
10091
else:
10192
# Advanced multi-axis mask: extract index i from shape [N, D]
10293
gather_axis = 1 # dim index
@@ -108,11 +99,17 @@ def expand_boolean_indices(
10899
set_layer_name(
109100
gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
110101
)
111-
extracted_index = gather_layer.get_output(0)
112-
squeeze_layer = ctx.net.add_shuffle(extracted_index)
113-
squeeze_layer.reshape_dims = (-1,)
114-
squeezed_index = squeeze_layer.get_output(0)
115-
new_indices.append(squeezed_index)
102+
to_squeeze = gather_layer.get_output(0)
103+
squeeze_layer = ctx.net.add_shuffle(to_squeeze)
104+
squeeze_layer.reshape_dims = (-1,)
105+
set_layer_name(
106+
squeeze_layer,
107+
target,
108+
name + f"_bool_mask_squeeze_{i}",
109+
source_ir,
110+
)
111+
squeezed_index = squeeze_layer.get_output(0)
112+
new_indices.append(squeezed_index)
116113
else:
117114
new_indices.append(ind)
118115
return new_indices

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ class TestIndexConstantConverter(DispatchTestCase):
8686
"mask_index_multi_axis",
8787
[
8888
None,
89-
torch.tensor([[True, False, False, True]]), # axis 1
89+
torch.tensor([True, False]), # axis 1
9090
None,
9191
torch.tensor([True, False]), # axis 3
9292
],
93-
torch.randn(2, 4, 4, 2),
93+
torch.randn(2, 2, 2, 2),
9494
),
9595
]
9696
)

0 commit comments

Comments
 (0)