Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,11 @@ def index_dtype_validator(
for ind in index:
if ind is not None:
val = ind.meta.get("val")
if val is not None and val.dtype not in (torch.int32, torch.int64):
if val is not None and val.dtype not in (
torch.int32,
torch.int64,
torch.bool,
):
return False
return True

Expand All @@ -423,6 +427,7 @@ def index_dtype_validator(
torch.ops.aten.index.Tensor,
capability_validator=index_dtype_validator,
supports_dynamic_shapes=True,
requires_output_allocator=True,
)
@enforce_tensor_types(
{
Expand Down
75 changes: 72 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
cast_trt_tensor,
get_positive_dim,
get_trt_tensor,
has_dynamic_shape,
set_layer_name,
to_numpy,
)
Expand Down Expand Up @@ -51,6 +50,77 @@ def select(
return layer.get_output(0)


def is_boolean_tensor(
    tensor: Union[TRTTensor, np.ndarray, torch.Tensor, torch.fx.Node],
) -> bool:
    """Return True when ``tensor`` carries boolean data.

    Args:
        tensor: A torch tensor, a numpy array, a TensorRT tensor, or an
            FX-node-like object whose ``meta["val"]`` entry exposes a
            ``dtype``. ``None`` is tolerated and reported as non-boolean.

    Returns:
        True if the dtype of ``tensor`` (or of its ``meta["val"]`` entry) is
        the boolean dtype of the corresponding framework, False otherwise.
    """
    # An absent index can never be a boolean mask; answer directly instead of
    # failing on attribute access further down.
    if tensor is None:
        return False
    if isinstance(tensor, torch.Tensor):
        return bool(tensor.dtype == torch.bool)
    if isinstance(tensor, np.ndarray):
        return bool(tensor.dtype == np.bool_)
    if isinstance(tensor, TRTTensor):
        return bool(tensor.dtype == trt.DataType.BOOL)

    # Remaining inputs are treated as node-like: the dtype is read from the
    # "val" entry of their meta mapping. Missing metadata, or a "val" entry
    # without a dtype, means "not boolean" rather than an AttributeError.
    meta = getattr(tensor, "meta", None)
    if not meta:
        return False
    val = meta.get("val")
    return getattr(val, "dtype", None) is torch.bool


def expand_boolean_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
    """Replace boolean-mask entries of ``indices`` with integer index tensors.

    Every entry for which ``is_boolean_tensor`` is true is fed through a
    TensorRT non-zero layer, and the resulting coordinates are appended to the
    output as flat index vectors. All other entries (including ``None``) are
    passed through untouched, in order.

    Per the TensorRT ``INonZeroLayer`` documentation the layer output has
    shape ``[D, C]`` (``D`` = rank of the mask, ``C`` = number of non-zero
    elements), i.e. row ``d`` holds the coordinates along mask dimension
    ``d``. Consequently:

    * a rank-1 mask yields exactly one index vector (the single row,
      flattened to ``[C]``), regardless of its position in ``indices``;
    * a rank-``D`` mask yields ``D`` consecutive index vectors, one per mask
      dimension, mirroring ``mask.nonzero(as_tuple=True)``.

    Args:
        ctx: Conversion context; ``ctx.net`` is used to add layers.
        target: Target forwarded to ``set_layer_name``.
        source_ir: Source IR forwarded to ``set_layer_name``.
        name: Prefix used for the names of all created layers/constants.
        input: The tensor being indexed. Not used by this function; kept so
            the signature stays stable for callers.
        indices: The index list of the ``aten.index`` call.

    Returns:
        A new list in which boolean masks have been replaced by integer index
        tensors. It may be longer than ``indices`` when a mask has rank > 1.
    """
    new_indices = []
    for i, ind in enumerate(indices):
        if ind is None or not is_boolean_tensor(ind):
            new_indices.append(ind)
            continue

        _LOGGER.debug(
            f"Boolean index detected at position {i}, converting with nonzero()"
        )
        mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

        nonzero_layer = ctx.net.add_non_zero(mask_tensor)
        set_layer_name(nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir)
        # Shape [D, C]: one row of coordinates per mask dimension.
        nonzero_indices = nonzero_layer.get_output(0)

        mask_rank = len(mask_tensor.shape)
        if mask_rank <= 1:
            # Single row: flattening [1, C] already gives the index vector.
            per_dim_coords = [(f"{i}", nonzero_indices)]
        else:
            per_dim_coords = []
            for dim in range(mask_rank):
                suffix = f"{i}_{dim}"
                # Select row `dim` (axis 0) - the coordinates along that
                # mask dimension for every selected element.
                gather_layer = ctx.net.add_gather(
                    nonzero_indices,
                    get_trt_tensor(ctx, dim, name + f"_dim_index_{suffix}"),
                    0,
                )
                set_layer_name(
                    gather_layer,
                    target,
                    name + f"_bool_nonzero_extract_{suffix}",
                    source_ir,
                )
                per_dim_coords.append((suffix, gather_layer.get_output(0)))

        for suffix, coords in per_dim_coords:
            # Collapse to a 1-D vector of length C whatever the rank of the
            # gather result (the row-selector constant may be 0-d or 1-d).
            squeeze_layer = ctx.net.add_shuffle(coords)
            squeeze_layer.reshape_dims = (-1,)
            set_layer_name(
                squeeze_layer,
                target,
                name + f"_bool_mask_squeeze_{suffix}",
                source_ir,
            )
            new_indices.append(squeeze_layer.get_output(0))
    return new_indices


def index(
ctx: ConversionContext,
target: Target,
Expand All @@ -61,13 +131,12 @@ def index(
) -> TRTTensor:
adv_indx_indices = []
tensor_indices = []
# check if the input is dynamic
dynamic_shape = has_dynamic_shape(input.shape)
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
# If any is not this flag will be set to False
_LOGGER.debug(
"Determining whether aten.index constant-index optimization can be invoked"
)
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
is_numpy = all(
isinstance(ind, (torch.Tensor, np.ndarray))
for ind in indices
Expand Down
69 changes: 68 additions & 1 deletion tests/py/dynamo/conversion/test_index_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,27 @@ class TestIndexConstantConverter(DispatchTestCase):
[None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
torch.randn(2, 4, 4, 2),
),
(
"mask_index_three_dim",
[None, torch.tensor([True, False]), None],
torch.randn(2, 2, 2),
),
(
"mask_index_two_dim",
[torch.tensor([True, False])],
torch.randn(2, 2),
),
(
# covers multi axis and discontinuous indices
"mask_index_multi_axis",
[
None,
torch.tensor([True, False]), # axis 1
None,
torch.tensor([True, False]), # axis 3
],
torch.randn(2, 2, 2, 2),
),
]
)
def test_index_constant(self, _, index, input):
Expand Down Expand Up @@ -104,6 +125,17 @@ def forward(self, x, index0):
[input, index0],
)

def test_index_zero_two_dim_ITensor_mask(self):
    # A boolean mask is passed as the second module input and applied to
    # dim 1 of a 2x2 tensor (dim 0 is left untouched via the None entry).
    class TestModule(nn.Module):
        def forward(self, x, index0):
            return torch.ops.aten.index.Tensor(x, [None, index0])

    data = torch.randn(2, 2)
    mask = torch.tensor([True, False])
    self.run_test(TestModule(), [data, mask], enable_passes=True)

def test_index_zero_index_three_dim_ITensor(self):
class TestModule(nn.Module):
def forward(self, x, index0):
Expand All @@ -116,6 +148,17 @@ def forward(self, x, index0):
index0 = index0.to(torch.int32)
self.run_test(TestModule(), [input, index0])

def test_index_zero_index_three_dim_mask_ITensor(self):
    # A boolean mask is passed as the second module input and applied to the
    # middle dimension of a 2x2x2 tensor; the outer dims use None entries.
    class TestModule(nn.Module):
        def forward(self, x, index0):
            return torch.ops.aten.index.Tensor(x, [None, index0, None])

    data = torch.randn(2, 2, 2)
    mask = torch.tensor([True, False])
    self.run_test(TestModule(), [data, mask])


class TestIndexDynamicConstantConverter(DispatchTestCase):
@parameterized.expand(
Expand Down Expand Up @@ -168,7 +211,31 @@ def forward(self, input):
dtype=torch.float32,
),
]
self.run_test_with_dynamic_shape(TestModule(), input_specs)
self.run_test_with_dynamic_shape(
TestModule(), input_specs, use_dynamo_tracer=True
)


class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
    """aten.index on a dynamically shaped input with indices computed in-graph."""

    def test_index_input_non_dynamic_index_dynamic(self):
        class TestIndexWithRuntimeIndex(torch.nn.Module):
            def forward(self, x):
                # The index tuple is derived from the data itself, so it only
                # exists inside the traced graph.
                idx = torch.nonzero(x > 0, as_tuple=True)
                return torch.ops.aten.index.Tensor(x, idx)

        dynamic_input = Input(
            min_shape=(2, 2),
            opt_shape=(2, 2),
            max_shape=(8, 8),
            dtype=torch.float32,
        )
        # With use_dynamo_tracer=True the index argument (args[1]) is itself
        # converted to a list of TRTTensors.
        self.run_test_with_dynamic_shape(
            TestIndexWithRuntimeIndex(),
            [dynamic_input],
            use_dynamo_tracer=True,
        )


if __name__ == "__main__":
Expand Down
Loading