Skip to content

Commit c286767

Browse files
committed
Supported bool mask indicies
1 parent 6ea89ae commit c286767

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,9 @@ def index_put_converter(
550550
accumulate: bool = False,
551551
) -> TRTTensor:
552552
# Convert 'input_indices' to TRT tensors (or keep None as is)
553+
input_indices = expand_boolean_indices(
554+
ctx, target, source_ir, name, input_tensor, input_indices
555+
)
553556
indices: List[Optional[Union[TRTTensor, None]]] = []
554557
for i, idx in enumerate(input_indices):
555558
if idx is None:
@@ -828,20 +831,15 @@ def index_put_converter(
828831
) + list(values.shape)
829832
broadcast_shape = []
830833
for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
831-
if val_dim == 1 or exp_dim == val_dim:
834+
if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
835+
broadcast_shape.append(-1)
836+
elif val_dim == 1 or exp_dim == val_dim:
832837
broadcast_shape.append(exp_dim)
833838
else:
834839
raise ValueError(
835840
f"Cannot broadcast {values.shape} to {expected_shape}"
836841
)
837-
# values_reshaped = impl.shuffle.reshape(
838-
# ctx,
839-
# target,
840-
# source_ir,
841-
# f"{name}_reshape_values",
842-
# values,
843-
# tuple(broadcast_shape),
844-
# )
842+
845843
values_expanded = impl.slice.expand(
846844
ctx,
847845
target,

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,43 @@ def forward(self, x, y, z, a, b):
328328
result = trt_mod(*inputs)
329329
assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
330330

331+
def test_bool_mask_test(self):
332+
333+
source_tensor = torch.ones([5, 10], dtype=torch.float32).cuda()
334+
indices_tensor = torch.tensor([False, False, True, False, True])
335+
value_tensor = torch.zeros([2, 10], dtype=torch.float32).cuda()
336+
337+
dim1 = torch.export.Dim("dim1", min=1, max=5)
338+
dim2 = torch.export.Dim("dim2", min=1, max=5)
339+
340+
class TestIndexPut(torch.nn.Module):
341+
def forward(self, source_tensor, indices_tensor, value_tensor):
342+
source_tensor[indices_tensor] = value_tensor
343+
return source_tensor
344+
345+
model = TestIndexPut()
346+
torch_output = model.forward(source_tensor, indices_tensor, value_tensor)
347+
348+
ep = torch.export.export(
349+
model,
350+
(source_tensor, indices_tensor, value_tensor),
351+
dynamic_shapes=({0: dim1}, {0: dim1}, {0: dim2}),
352+
)
353+
with torchtrt.dynamo.Debugger(log_level="debug"):
354+
trt_engine = torchtrt.dynamo.compile(
355+
ep,
356+
inputs=(source_tensor, indices_tensor, value_tensor),
357+
enabled_precisions={torch.float32},
358+
min_block_size=1,
359+
use_explicit_typing=False,
360+
use_fp32_acc=False,
361+
disable_tf32=True,
362+
use_python_runtime=True,
363+
)
364+
result = trt_engine(source_tensor, indices_tensor, value_tensor)
365+
366+
torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
367+
331368

332369
if __name__ == "__main__":
333370
run_tests()

0 commit comments

Comments
 (0)