From c60e9b6f48d876be98cd2fcb81f7304908a0df52 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 17 Jul 2025 11:13:45 -0700 Subject: [PATCH 1/6] Index converter dynamic cases fix --- tests/py/dynamo/conversion/test_index_aten.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 8e21f945dc..fc4a70b1ff 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -168,7 +168,31 @@ def forward(self, input): dtype=torch.float32, ), ] - self.run_test_with_dynamic_shape(TestModule(), input_specs) + self.run_test_with_dynamic_shape( + TestModule(), input_specs, use_dynamo_tracer=True + ) + + +class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase): + def test_index_input_non_dynamic_index_dynamic(self): + class TestIndexWithRuntimeIndex(torch.nn.Module): + def forward(self, x): + mask = x > 0 + idx = torch.nonzero(mask, as_tuple=True) + return torch.ops.aten.index.Tensor(x, idx) + + input_specs = [ + Input( + min_shape=(2, 2), + opt_shape=(2, 2), + max_shape=(8, 8), + dtype=torch.float32, + ), + ] + # In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True + self.run_test_with_dynamic_shape( + TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True + ) if __name__ == "__main__": From d9cda495957f51e4a53491156e7a88ae7655b3c4 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 31 Jul 2025 15:53:21 -0700 Subject: [PATCH 2/6] support for boolean indices --- .../dynamo/conversion/aten_ops_converters.py | 6 +- .../dynamo/conversion/impl/select.py | 62 ++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 65923c7dac..178caa17c1 100644 --- 
a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -414,7 +414,11 @@ def index_dtype_validator( for ind in index: if ind is not None: val = ind.meta.get("val") - if val is not None and val.dtype not in (torch.int32, torch.int64): + if val is not None and val.dtype not in ( + torch.int32, + torch.int64, + torch.bool, + ): return False return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index c4d44a07ea..10a8332538 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -51,6 +51,65 @@ def select( return layer.get_output(0) +def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: + if isinstance(tensor, (TRTTensor)): + val = tensor.meta.get("val") + if val is not None and val.dtype is torch.bool: + return True + return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool + + +def expand_boolean_indices( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], +) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + for i, ind in enumerate(indices): + if ind is not None and is_boolean_tensor(ind): + _LOGGER.debug( + f"Boolean index detected at position {i}, converting with nonzero()" + ) + + mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}") + + nonzero_layer = ctx.net.add_non_zero(mask_tensor) + set_layer_name( + nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir + ) + nonzero_indices = nonzero_layer.get_output(0) + + # nonzero returns shape [N, dims], we need to extract dim i + if len(indices) == 1: + # x[mask] — 1D mask + squeeze_layer = ctx.net.add_shuffle(nonzero_indices) + squeeze_layer.reshape_dims = (-1,) + set_layer_name( + squeeze_layer, + 
target, + name + f"_bool_nonzero_squeeze_{i}", + source_ir, + ) + squeezed_index = squeeze_layer.get_output(0) + ind = squeezed_index + else: + # Advanced multi-axis mask: extract index i from shape [N, D] + gather_axis = 1 # dim index + gather_layer = ctx.net.add_gather( + nonzero_indices, + get_trt_tensor(ctx, i, name + f"_dim_index_{i}"), + gather_axis, + ) + set_layer_name( + gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir + ) + extracted_index = gather_layer.get_output(0) + ind = extracted_index + return indices + + def index( ctx: ConversionContext, target: Target, @@ -61,8 +120,6 @@ def index( ) -> TRTTensor: adv_indx_indices = [] tensor_indices = [] - # check if the input is dynamic - dynamic_shape = has_dynamic_shape(input.shape) # is_numpy is a flag to specify if all the indices are numpy or torchTensor. # If any is not this flag will be set to False _LOGGER.debug( @@ -76,6 +133,7 @@ def index( # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None + indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") From 03cf9034916bfb19b452b1671ca6827d0a5dd029 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 21 Aug 2025 17:28:39 -0700 Subject: [PATCH 3/6] mask test cases and correction --- .../dynamo/conversion/aten_ops_converters.py | 1 + .../dynamo/conversion/impl/select.py | 22 +++++++++++++------ tests/py/dynamo/conversion/test_index_aten.py | 10 +++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 178caa17c1..e9be9c9b89 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -427,6 +427,7 @@ def index_dtype_validator( 
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator, supports_dynamic_shapes=True, + requires_output_allocator=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 10a8332538..85de72893d 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -14,7 +14,6 @@ cast_trt_tensor, get_positive_dim, get_trt_tensor, - has_dynamic_shape, set_layer_name, to_numpy, ) @@ -52,10 +51,14 @@ def select( def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: - if isinstance(tensor, (TRTTensor)): + if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)): + return bool(tensor.dtype == torch.bool) + # when index is a node + else: val = tensor.meta.get("val") if val is not None and val.dtype is torch.bool: return True + return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool @@ -67,12 +70,12 @@ def expand_boolean_indices( input: TRTTensor, indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], ) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]: + new_indices = [] for i, ind in enumerate(indices): if ind is not None and is_boolean_tensor(ind): _LOGGER.debug( f"Boolean index detected at position {i}, converting with nonzero()" ) - mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}") nonzero_layer = ctx.net.add_non_zero(mask_tensor) @@ -93,7 +96,7 @@ def expand_boolean_indices( source_ir, ) squeezed_index = squeeze_layer.get_output(0) - ind = squeezed_index + new_indices.append(squeezed_index) else: # Advanced multi-axis mask: extract index i from shape [N, D] gather_axis = 1 # dim index @@ -106,8 +109,13 @@ def expand_boolean_indices( gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir ) extracted_index = gather_layer.get_output(0) - ind = extracted_index - return indices + squeeze_layer = 
ctx.net.add_shuffle(extracted_index) + squeeze_layer.reshape_dims = (-1,) + squeezed_index = squeeze_layer.get_output(0) + new_indices.append(squeezed_index) + else: + new_indices.append(ind) + return new_indices def index( @@ -125,6 +133,7 @@ def index( _LOGGER.debug( "Determining whether aten.index constant-index optimization can be invoked" ) + indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) is_numpy = all( isinstance(ind, (torch.Tensor, np.ndarray)) for ind in indices @@ -133,7 +142,6 @@ def index( # here we need to check if all the index are broadcastable # if no, then we need to broadcast last_index = None - indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices) for i, ind in enumerate(indices): if ind is not None: _LOGGER.debug(f"Shape of {i} index is {ind.shape}") diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index fc4a70b1ff..5aa44c02c4 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -71,6 +71,16 @@ class TestIndexConstantConverter(DispatchTestCase): [None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])], torch.randn(2, 4, 4, 2), ), + ( + "mask_index_three_dim", + [None, torch.tensor([True, False]), None], + torch.randn(2, 2, 2), + ), + ( + "mask_index_two_dim", + [torch.tensor([True, False])], + torch.randn(2, 2), + ), ] ) def test_index_constant(self, _, index, input): From 0e4ee771f3edb1b053ce7e6e567ab86331d7798c Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 22 Aug 2025 16:26:04 -0700 Subject: [PATCH 4/6] adding the discontinuous mask indices case --- tests/py/dynamo/conversion/test_index_aten.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index 5aa44c02c4..f7278f84a6 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ 
b/tests/py/dynamo/conversion/test_index_aten.py @@ -81,6 +81,17 @@ class TestIndexConstantConverter(DispatchTestCase): [torch.tensor([True, False])], torch.randn(2, 2), ), + ( + # covers multi axis and discontinuous indices + "mask_index_multi_axis", + [ + None, + torch.tensor([[True, False, False, True]]), # axis 1 + None, + torch.tensor([True, False]), # axis 3 + ], + torch.randn(2, 4, 4, 2), + ), ] ) def test_index_constant(self, _, index, input): From 3a2969cbd37ad9b6f65a923e32158c773a1798e7 Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 22 Aug 2025 16:39:33 -0700 Subject: [PATCH 5/6] unifying the squeeze layer --- .../dynamo/conversion/impl/select.py | 27 +++++++++---------- tests/py/dynamo/conversion/test_index_aten.py | 4 +-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index 85de72893d..ded50519ad 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -87,16 +87,7 @@ def expand_boolean_indices( # nonzero returns shape [N, dims], we need to extract dim i if len(indices) == 1: # x[mask] — 1D mask - squeeze_layer = ctx.net.add_shuffle(nonzero_indices) - squeeze_layer.reshape_dims = (-1,) - set_layer_name( - squeeze_layer, - target, - name + f"_bool_nonzero_squeeze_{i}", - source_ir, - ) - squeezed_index = squeeze_layer.get_output(0) - new_indices.append(squeezed_index) + to_squeeze = nonzero_indices else: # Advanced multi-axis mask: extract index i from shape [N, D] gather_axis = 1 # dim index @@ -108,11 +99,17 @@ def expand_boolean_indices( set_layer_name( gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir ) - extracted_index = gather_layer.get_output(0) - squeeze_layer = ctx.net.add_shuffle(extracted_index) - squeeze_layer.reshape_dims = (-1,) - squeezed_index = squeeze_layer.get_output(0) - new_indices.append(squeezed_index) + to_squeeze = 
gather_layer.get_output(0) + squeeze_layer = ctx.net.add_shuffle(to_squeeze) + squeeze_layer.reshape_dims = (-1,) + set_layer_name( + squeeze_layer, + target, + name + f"_bool_mask_squeeze_{i}", + source_ir, + ) + squeezed_index = squeeze_layer.get_output(0) + new_indices.append(squeezed_index) else: new_indices.append(ind) return new_indices diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index f7278f84a6..a98a9cba76 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -86,11 +86,11 @@ class TestIndexConstantConverter(DispatchTestCase): "mask_index_multi_axis", [ None, - torch.tensor([[True, False, False, True]]), # axis 1 + torch.tensor([True, False]), # axis 1 None, torch.tensor([True, False]), # axis 3 ], - torch.randn(2, 4, 4, 2), + torch.randn(2, 2, 2, 2), ), ] ) From e659a901242fa69192c869eedabf417f216bdebf Mon Sep 17 00:00:00 2001 From: apbose Date: Fri, 29 Aug 2025 15:04:57 -0700 Subject: [PATCH 6/6] addressing review comments --- .../dynamo/conversion/impl/select.py | 12 +++++++--- tests/py/dynamo/conversion/test_index_aten.py | 22 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py index ded50519ad..6f4a812dd8 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/select.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -50,16 +50,22 @@ def select( return layer.get_output(0) -def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool: - if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)): +def is_boolean_tensor( + tensor: Union[TRTTensor, np.ndarray, torch.Tensor, torch.fx.Node], +) -> bool: + if isinstance(tensor, torch.Tensor): return bool(tensor.dtype == torch.bool) + elif isinstance(tensor, np.ndarray): + return bool(tensor.dtype == np.bool_) + elif 
isinstance(tensor, TRTTensor): + return bool(tensor.dtype == trt.DataType.BOOL) # when index is a node else: val = tensor.meta.get("val") if val is not None and val.dtype is torch.bool: return True - return isinstance(tensor, (torch.Tensor, np.ndarray)) and tensor.dtype == torch.bool + return False def expand_boolean_indices( diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py index a98a9cba76..05d86d382b 100644 --- a/tests/py/dynamo/conversion/test_index_aten.py +++ b/tests/py/dynamo/conversion/test_index_aten.py @@ -125,6 +125,17 @@ def forward(self, x, index0): [input, index0], ) + def test_index_zero_two_dim_ITensor_mask(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2) + index0 = torch.tensor([True, False]) + self.run_test(TestModule(), [input, index0], enable_passes=True) + def test_index_zero_index_three_dim_ITensor(self): class TestModule(nn.Module): def forward(self, x, index0): @@ -137,6 +148,17 @@ def forward(self, x, index0): index0 = index0.to(torch.int32) self.run_test(TestModule(), [input, index0]) + def test_index_zero_index_three_dim_mask_ITensor(self): + class TestModule(nn.Module): + def forward(self, x, index0): + indices = [None, index0, None] + out = torch.ops.aten.index.Tensor(x, indices) + return out + + input = torch.randn(2, 2, 2) + index0 = torch.tensor([True, False]) + self.run_test(TestModule(), [input, index0]) + class TestIndexDynamicConstantConverter(DispatchTestCase): @parameterized.expand(