
Commit 22edfe6

Covered the bool mask cases

1 parent fa53986

File tree: 4 files changed (+110, -1 lines)

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 3 additions & 1 deletion

@@ -759,7 +759,9 @@ def index_put_converter(
     ) + list(values.shape)
     broadcast_shape = []
     for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
-        if val_dim == 1 or exp_dim == val_dim:
+        if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
+            broadcast_shape.append(-1)
+        elif val_dim == 1 or exp_dim == val_dim:
             broadcast_shape.append(exp_dim)
         else:
             raise ValueError(
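The converter change above propagates dynamic dimensions through the broadcast-shape computation instead of erroring on them. A minimal standalone sketch of the rule, assuming DYNAMIC_DIM is the -1 sentinel torch_tensorrt uses for dynamic dimensions (the helper name below is invented for illustration):

DYNAMIC_DIM = -1  # assumption: torch_tensorrt's sentinel for a dynamic dimension

def compute_broadcast_shape(expected_shape, values_shape_padded):
    # Mirrors the converter logic: dynamic dims stay dynamic (-1), size-1 dims
    # broadcast, equal dims pass through, anything else is an error.
    broadcast_shape = []
    for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
        if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
            broadcast_shape.append(-1)
        elif val_dim == 1 or exp_dim == val_dim:
            broadcast_shape.append(exp_dim)
        else:
            raise ValueError(f"cannot broadcast value dim {val_dim} to expected dim {exp_dim}")
    return broadcast_shape

print(compute_broadcast_shape([-1, 10], [1, 10]))  # [-1, 10]: the dynamic dim propagates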

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,7 @@
 from .complex_graph_rewrite import complex_graph_detection
 from .constant_folding import constant_fold
 from .fuse_prims_broadcast import fuse_prims_broadcast
+from .index_put_replace_bool_with_indices import index_put_replace_bool_with_indices
 from .pass_manager import DynamoPassManager
 from .remove_assert_nodes import remove_assert_nodes
 from .remove_detach import remove_detach
@@ -22,6 +23,7 @@
     repair_input_as_output,
     fuse_prims_broadcast,
     replace_max_pool_with_indices,
+    index_put_replace_bool_with_indices,
     remove_assert_nodes,
     remove_num_users_is_0_nodes,
     complex_graph_detection,
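The registration above only appends the new pass to the ordered lowering-pass list. As a rough illustration of the pattern (not torch_tensorrt's actual DynamoPassManager implementation), applying such a list amounts to threading the GraphModule through each pass in order:

import torch.fx as fx
from torch_tensorrt.dynamo._settings import CompilationSettings

def apply_lowering_passes(gm: fx.GraphModule, settings: CompilationSettings, passes) -> fx.GraphModule:
    # Each lowering pass has the signature (GraphModule, CompilationSettings) -> GraphModule
    # and may rewrite the graph in place before returning it.
    for lowering_pass in passes:
        gm = lowering_pass(gm, settings)
    return gm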
py/torch_tensorrt/dynamo/lowering/passes/index_put_replace_bool_with_indices.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+import logging
+import operator
+
+import torch
+import torch.fx as fx
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _bool_tensor_to_long_indices(
+    graph: fx.Graph, mask_node: fx.Node, before: fx.Node
+) -> fx.Node:
+
+    with graph.inserting_before(before):
+        nz_tuple = graph.call_function(
+            torch.nonzero, args=(mask_node,), kwargs={"as_tuple": True}
+        )
+        idx = graph.call_function(operator.getitem, args=(nz_tuple, 0))
+
+    return idx
+
+
+def index_put_replace_bool_with_indices(
+    gm: fx.GraphModule, settings: CompilationSettings
+) -> fx.GraphModule:
+
+    graph = gm.graph
+    modified_graph = False
+    for node in list(graph.nodes):
+        if node.target != torch.ops.aten.index_put.default:
+            continue
+
+        indices = node.args[1]
+        if isinstance(indices, (list, tuple)):
+            new_elems = []
+            for it in indices:
+                if isinstance(it, fx.Node) and it.meta["val"].dtype == torch.bool:
+                    # bool Tensor → long indices Tensor
+                    idx = _bool_tensor_to_long_indices(graph, it, before=node)
+                    new_elems.append(idx)
+                elif isinstance(it, (list, tuple)) and all(
+                    isinstance(b, bool) for b in it
+                ):
+                    new_elems.append([i for i, b in enumerate(it) if b])
+                else:
+                    new_elems.append(it)
+            new_indices = type(indices)(new_elems)
+            node.args = (node.args[0], new_indices, *node.args[2:])
+        elif isinstance(indices, fx.Node) and indices.meta["val"].dtype == torch.bool:
+            idx = _bool_tensor_to_long_indices(graph, indices, before=node)
+            node.args = (node.args[0], idx, *node.args[2:])
+        modified_graph = True
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+
+    return gm
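The pass rewrites boolean-mask index arguments of aten.index_put into integer index tensors via nonzero, presumably so the existing index_put converter only has to handle integer indices. The rewrite relies on the standard PyTorch equivalence below; this eager-mode check is not part of the commit, just a sanity sketch of the semantics for a 1-D mask:

import torch

src_mask = torch.ones(5, 10)
src_idx = src_mask.clone()
mask = torch.tensor([False, False, True, False, True])
values = torch.zeros(2, 10)

src_mask[mask] = values                             # bool-mask form of index_put
src_idx[mask.nonzero(as_tuple=True)[0]] = values    # long-index form produced by the pass

assert torch.equal(src_mask, src_idx)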

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 44 additions & 0 deletions

@@ -328,6 +328,50 @@ def forward(self, x, y, z, a, b):
         result = trt_mod(*inputs)
         assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)

+    def test_bool_mask_test(self):
+
+        source_tensor = torch.ones([5, 10], dtype=torch.float32).cuda()
+        indices_tensor = torch.tensor([False, False, True, False, True])
+        # indices_tensor = torch.tensor([3,4])
+        value_tensor = torch.zeros([2, 10], dtype=torch.float32).cuda()
+
+        dim1 = torch.export.Dim("dim1", min=1, max=5)
+        dim2 = torch.export.Dim("dim2", min=1, max=5)
+
+        # source_tensor=torch.zeros([5, 5], dtype=torch.int32).cuda()
+        # indices_tensor=(torch.tensor([0, 0], dtype=torch.int32).cuda(), torch.tensor([1, 1], dtype=torch.int32).cuda())
+        # value_tensor=torch.tensor([1, 2], dtype=torch.int32).cuda()
+        # accumulate=False
+
+        class TestIndexPut(torch.nn.Module):
+            def forward(self, source_tensor, indices_tensor, value_tensor):
+                # indices_tensor = torch.where(indices_tensor)[0]
+                source_tensor[indices_tensor] = value_tensor
+                return source_tensor
+
+        model = TestIndexPut()
+        torch_output = model.forward(source_tensor, indices_tensor, value_tensor)
+
+        ep = torch.export.export(
+            model,
+            (source_tensor, indices_tensor, value_tensor),
+            dynamic_shapes=({0: dim1}, {0: dim1}, {0: dim2}),
+        )
+        with torchtrt.dynamo.Debugger(log_level="debug"):
+            trt_engine = torchtrt.dynamo.compile(
+                ep,
+                inputs=(source_tensor, indices_tensor, value_tensor),
+                enabled_precisions={torch.float32},
+                min_block_size=1,
+                use_explicit_typing=False,
+                use_fp32_acc=False,
+                disable_tf32=True,
+                use_python_runtime=True,
+            )
+        result = trt_engine(source_tensor, indices_tensor, value_tensor)
+
+        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
+

 if __name__ == "__main__":
     run_tests()
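To exercise just this case locally (assuming a working CUDA + TensorRT development install of torch_tensorrt), a pytest name filter along these lines should select it:

pytest tests/py/dynamo/conversion/test_index_put_aten.py -k bool_mask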
