5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -888,6 +888,7 @@ def aten_ops_select(

@dynamo_tensorrt_converter(
torch.ops.aten.index_put.default,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
@@ -3168,7 +3169,9 @@ def aten_ops_upsample_bicubic2d(


@dynamo_tensorrt_converter(
torch.ops.aten.topk.default, capability_validator=topk_validator
torch.ops.aten.topk.default,
capability_validator=topk_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
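For reference, a minimal eager-mode sketch (plain PyTorch, not part of this diff) of the aten.index_put semantics that the converter registered above with supports_dynamic_shapes=True, and rewritten in select.py below, has to reproduce: accumulate=False overwrites the addressed elements, while accumulate=True adds onto them.

import torch

x = torch.zeros(3, 3)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 1])

# accumulate=False overwrites the addressed elements ...
x.index_put_((rows, cols), torch.tensor([5.0, 7.0]))
# ... while accumulate=True adds onto whatever is already stored there.
x.index_put_((rows, cols), torch.tensor([1.0, 1.0]), accumulate=True)

print(x)
# tensor([[0., 6., 0.],
#         [0., 0., 0.],
#         [0., 8., 0.]])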
249 changes: 149 additions & 100 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
Collaborator:
Could you add detailed comments demonstrating what it does, now that you've gone through the entire converter? That would be helpful.
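As context for the converter diff below: the rewritten index_put_converter splits the input's dimensions into indexed dimensions I (those that have an index tensor) and free dimensions F, builds an (N * F_volume, rank) coordinate matrix by concatenating the index tensors with arange meshgrids over the free dimensions, flattens the values to match, and feeds both to a ScatterND layer. A hedged eager-mode sketch of that decomposition for a 2-D input with one indexed dimension; every name here is illustrative and not part of the PR:

import torch

x = torch.zeros(4, 3)
idx = torch.tensor([1, 3])                 # N = 2 indexed positions along dim 0
values = torch.tensor([[10.0, 11.0, 12.0],
                       [20.0, 21.0, 22.0]])

N, F_volume, rank = idx.shape[0], x.shape[1], 2

# Indexed dim (I): repeat each index across the free dimension.
ii0 = idx.view(N, 1, 1).expand(N, F_volume, 1)
# Free dim (F): an arange meshgrid broadcast across the indexed positions.
ii1 = torch.arange(F_volume).view(1, F_volume, 1).expand(N, F_volume, 1)

coords = torch.cat([ii0, ii1], dim=2).reshape(-1, rank)   # (N * F_volume, rank)
flat_vals = values.reshape(-1)                            # (N * F_volume,)

# ScatterND equivalent: write each value at its (row, col) coordinate.
out = x.clone()
out[coords[:, 0], coords[:, 1]] = flat_vals

assert torch.equal(out, x.index_put((idx,), values))

In the converter these roles are played by I_tensors, the expanded aranges over F_shapes, indices_cat reshaped to (-1, rank), and flattened_values.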

@@ -257,15 +257,17 @@ def index(
)
else:
dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
mult_d1 = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
mult_d1,
dim_tensor_shape_mult_d1,
)

if isinstance(dim_tensor_shape_mult_d1, TRTTensor):
mult_d1 = convert_binary_elementwise(
ctx,
target,
source_ir,
name + f"_shape_{i}",
trt.ElementWiseOperation.PROD,
mult_d1,
dim_tensor_shape_mult_d1,
)

concat_tensor_layer = ctx.net.add_concatenation(
[
@@ -548,6 +550,9 @@ def index_put_converter(
accumulate: bool = False,
) -> TRTTensor:
# Convert 'input_indices' to TRT tensors (or keep None as is)
input_indices = expand_boolean_indices(
ctx, target, source_ir, name, input_tensor, input_indices
)
indices: List[Optional[Union[TRTTensor, None]]] = []
for i, idx in enumerate(input_indices):
if idx is None:
@@ -571,22 +576,40 @@
K = len(I)
# Determine the maximum size 'N' among the index tensors
if K > 0:
index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
index_shapes = (
[]
) # [tensor.shape[0] for tensor in indices if tensor is not None]
for idx_tensor in indices:
if idx_tensor is not None:
if idx_tensor.shape[0] != DYNAMIC_DIM:
index_shapes.append(idx_tensor.shape[0])
else:
index_shapes.append(
get_shape(
ctx,
target,
source_ir,
name + "idx_shape_dim_0",
idx_tensor,
0,
)
)
N = max(index_shapes) if index_shapes else 1
else:
N = 1

# Compute shapes and volume for the free dimensions
F_shapes = [input_tensor.shape[i] for i in F]
assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
F_volume = trt.volume(F_shapes) if F_shapes else 1

# Process indexed dimensions (I)
I_tensors = []
for i in I:
idx = indices[i]
assert idx is not None
idx_reshaped = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
idx_reshaped = impl.unsqueeze.unsqueeze(
ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
)
expanded_idx = impl.slice.expand(
ctx,
@@ -608,46 +631,50 @@
)
arange_tensors.append(arange_tensor)

meshgrid_tensors = []
for i, arange in enumerate(arange_tensors):
reshape_shape = [1] * len(F)
reshape_shape[i] = F_shapes[i]
arange_reshaped = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_arange_F_{F[i]}",
arange,
tuple(reshape_shape),
)
expanded_arange = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_arange_F_{F[i]}",
arange_reshaped,
tuple(F_shapes),
)
meshgrid_tensors.append(expanded_arange)

meshgrid_stacked = impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_stack_meshgrid",
[
impl.shuffle.reshape(
if len(arange_tensors) == 1:
# No need to stack
meshgrid_stacked = arange_tensors[0]
else:
meshgrid_tensors = []
for i, arange in enumerate(arange_tensors):
reshape_shape = [1] * len(F)
reshape_shape[i] = F_shapes[i]
arange_reshaped = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_mesh_{i}",
t,
(*F_shapes, 1),
f"{name}_reshape_arange_F_{F[i]}",
arange,
tuple(reshape_shape),
)
for i, t in enumerate(meshgrid_tensors)
],
dim=-1,
)
expanded_arange = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_arange_F_{F[i]}",
arange_reshaped,
tuple(F_shapes),
)
meshgrid_tensors.append(expanded_arange)

meshgrid_stacked = impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_stack_meshgrid",
[
impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_mesh_{i}",
t,
(*F_shapes, 1),
)
for i, t in enumerate(meshgrid_tensors)
],
dim=-1,
)
meshgrid_reshaped = impl.shuffle.reshape(
ctx,
target,
@@ -672,21 +699,15 @@

# Combine all indexed dimensions (I)
if K > 0:
I_combined = impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_cat_I",
[
impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
)
for i, t in enumerate(I_tensors)
],
dim=2,
)

I_combined = [
impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
)
for i, t in enumerate(I_tensors)
]
else:
I_combined = None
I_combined = []

# Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
ii_list = []
@@ -695,24 +716,12 @@
for dim in range(rank):
unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
if dim in I:
start = [0, 0, i_idx]
shape = [N, F_volume, 1]
stride = [1, 1, 1]
idx_tensor = impl.slice.slice(
ctx,
target,
source_ir,
f"{name}_slice_I_dim_{unique_suffix}",
I_combined,
start,
shape,
stride,
)
idx_tensor = I_combined[i_idx]
ii_list.append(idx_tensor)
i_idx += 1
else:
start = [0, 0, f_idx]
shape = [N, F_volume, 1]
shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
stride = [1, 1, 1]
mesh_tensor = impl.slice.slice(
ctx,
@@ -731,20 +740,24 @@
indices_cat = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
)

# Flatten the indices_cat to (N * F_volume, rank)
indices_cat = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_indices_cat",
indices_cat,
(N * F_volume, rank),
(-1, rank),
)

if not isinstance(values, TRTTensor):
values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)

# Define the expected shape based on (N,) + F_shapes
expected_shape = (N,) + tuple(F_shapes)
expected_shape = (
(-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
)

# Broadcast 'values' to match the expected shape
if len(values.shape) == 0 or values.shape == (1,): # Scalar case
@@ -761,7 +774,12 @@
)
else: # Non-scalar case
values_shape = list(values.shape)
if K > 0 and N in values_shape:
if (
K > 0
and N in values_shape
and (len(F) > 1 and max(F) - min(F) + 1 == len(F))
):
# Continuous case
n_idx = values_shape.index(N)
permute_order = [n_idx] + [
i for i in range(len(values_shape)) if i != n_idx
@@ -807,31 +825,27 @@
tuple(broadcast_shape),
)
else:
# Discontinuous case
values_shape_padded = [1] * (
len(expected_shape) - len(values.shape)
) + list(values.shape)
broadcast_shape = []
for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
if val_dim == 1 or exp_dim == val_dim:
if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
broadcast_shape.append(-1)
elif val_dim == 1 or exp_dim == val_dim:
broadcast_shape.append(exp_dim)
else:
raise ValueError(
f"Cannot broadcast {values.shape} to {expected_shape}"
)
values_reshaped = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_values",
values,
tuple(broadcast_shape),
)

values_expanded = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_values",
values_reshaped,
values,
expected_shape,
)

@@ -842,16 +856,51 @@
source_ir,
f"{name}_flatten_values",
values_expanded,
(N * F_volume,),
(-1,),
)

indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
# Perform Scatter ND operation
scatter_layer = ctx.net.add_scatter(
input_tensor,
indices_cat,
flattened_values,
trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
)
set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
return scatter_layer.get_output(0)
if accumulate:
zero_tensor = impl.full.full(
ctx,
target,
source_ir,
f"{name}_zero_tensor",
[
get_shape(
ctx,
target,
source_ir,
name + f"input_tensor_shape_dim_{i}",
input_tensor,
i,
)
for i in range(len(input_tensor.shape))
],
0.0,
dtype=input_tensor.dtype,
)
# Perform Scatter ND operation
scatter_layer = ctx.net.add_scatter(
zero_tensor,
indices_cat,
flattened_values,
trt.ScatterMode.ND,
)
set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)

scatter_out = scatter_layer.get_output(0)
result = impl.elementwise.add(
ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
)
return result

else:
scatter_layer = ctx.net.add_scatter(
input_tensor,
indices_cat,
flattened_values,
trt.ScatterMode.ND,
)
set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
scatter_out = scatter_layer.get_output(0)
return scatter_out
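The new accumulate branch above no longer uses ScatterMode.ND_ELEMENTWISE_ADD; instead it scatters the flattened values into a zero tensor shaped like the input (built per dimension with get_shape so dynamic dimensions are handled), then adds that tensor back onto the original input. A hedged PyTorch sketch of the equivalence, assuming the scattered coordinates do not repeat:

import torch

x = torch.arange(9.0).reshape(3, 3)
idx = (torch.tensor([0, 2]), torch.tensor([1, 1]))
values = torch.tensor([10.0, 20.0])

zeros = torch.zeros_like(x)      # stand-in for the full() zero tensor
zeros[idx] = values              # ScatterMode.ND into the zero tensor
accumulated = x + zeros          # elementwise add back onto the input

assert torch.equal(accumulated, x.index_put(idx, values, accumulate=True))

If the same coordinate appears more than once, an ND scatter typically keeps only one of the duplicate writes, so this decomposition (like the sketch) assumes unique coordinates.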
@@ -23,7 +23,8 @@ def remove_num_users_is_0_nodes(
and len(node.all_input_nodes) > 0
):
gm.graph.erase_node(node)
gm = clean_up_graph_after_modifications(gm)

gm = clean_up_graph_after_modifications(gm)

logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}")

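The last hunk dedents clean_up_graph_after_modifications so it runs once after the loop instead of once per erased node. A minimal sketch of that pass structure using torch.fx directly; the condition follows the pass name and the visible context, while the function name, clean-up body, and demo module are illustrative only:

import torch
from torch.fx import GraphModule, symbolic_trace


def remove_dead_nodes(gm: GraphModule) -> GraphModule:
    # Nodes that consume inputs but have no users, skipping the graph output.
    dead = [
        node
        for node in gm.graph.nodes
        if node.op != "output"
        and len(node.users) == 0
        and len(node.all_input_nodes) > 0
    ]
    # Erase in reverse order so producers are only removed after their consumers.
    for node in reversed(dead):
        gm.graph.erase_node(node)

    # Clean up once, after all erasures, rather than inside the loop.
    gm.graph.lint()
    gm.recompile()
    return gm


class M(torch.nn.Module):
    def forward(self, x):
        unused = x + 1      # traced into the graph but never used
        return x * 2


gm = remove_dead_nodes(symbolic_trace(M()))
print(gm.graph)  # the dead add node is gone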