38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -123,3 +123,41 @@ def get_shape_with_dynamic_shape(
    select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
    set_layer_name(select_layer, target, f"{name}_select")
    return select_layer.get_output(0)


def to_trt_shape_tensor(
    ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
) -> TRTTensor | List[int]:
    """
    Convert a mixed shape list (ints + ITensors) into a single ITensor.

    Args:
        ctx (ConversionContext): the active conversion context.
        target: fx node target (used for layer naming).
        name (str): base name for layer naming.
        shape_list (list[int | ITensor]): list containing static ints and/or ITensors.

    Returns:
        A 1-D ITensor if shape_list contains any ITensors, else the plain
        Python list of ints, which remains usable as a static shape.
    """
    # Fast path: a fully static shape needs no layers added to the network.
    if all(isinstance(s, int) for s in shape_list):
        return shape_list

    # Promote each static int to a 1-element int32 constant so that every
    # entry is an ITensor, then concatenate them into one shape tensor.
    trt_tensors = []
    for i, s in enumerate(shape_list):
        if isinstance(s, int):
            const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
            set_layer_name(const, target, f"{name}_dim{i}_const")
            trt_tensors.append(const.get_output(0))
        else:
            # Already an ITensor (assumed 1-D with a single element).
            trt_tensors.append(s)

    concat_layer = ctx.net.add_concatenation(trt_tensors)
    concat_layer.axis = 0
    set_layer_name(concat_layer, target, f"{name}_shape_concat")
    return concat_layer.get_output(0)
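
A minimal usage sketch (not part of the diff), assuming `ctx` wraps a live `INetworkDefinition`, `target` and the `"resize"` name come from the fx node being converted, and `dynamic_h` stands in for a 1-element int32 ITensor produced upstream (e.g. by a shape layer):

    # Fully static shape: no layers are added; the list passes through.
    shape = to_trt_shape_tensor(ctx, target, "resize", [1, 3, 32, 32])
    assert shape == [1, 3, 32, 32]

    # Mixed shape: ints become 1-element int32 constants and everything is
    # concatenated along axis 0 into a single 1-D shape tensor.
    shape = to_trt_shape_tensor(ctx, target, "resize", [1, 3, dynamic_h, 32])
    # `shape` is now an ITensor of shape (4,), suitable for
    # IResizeLayer.set_input(1, shape).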
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -9,7 +9,10 @@
    has_dynamic_shape,
    set_layer_name,
)
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.shape import (
+    get_shape_with_dynamic_shape,
+    to_trt_shape_tensor,
+)


def upsample(
@@ -28,14 +31,17 @@ def upsample(
    if scale_factor is not None:
        layer.scales = [1.0, 1.0] + list(scale_factor)
    else:
-        shape = list(input.shape)[:2] + list(size)
+        shape = list(input.shape)[:2]
+        if size is not None:
+            shape += list(size)
        if has_dynamic_shape(shape):
            shape = get_shape_with_dynamic_shape(
                ctx, target, source_ir, name, shape, input
            )
            layer.set_input(1, shape)
        else:
-            layer.shape = shape
+            shape = to_trt_shape_tensor(ctx, target, name, shape)
+            if isinstance(shape, list):
+                # All dims are static ints; the plain shape setter applies.
+                layer.shape = shape
+            else:
+                # Some dims are ITensors; pass the shape tensor as input 1.
+                layer.set_input(1, shape)

    if mode == "nearest":
        layer.resize_mode = trt.InterpolationMode.NEAREST
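
The dispatch above follows the `IResizeLayer` contract in the TensorRT Python API: the `shape` property only accepts static integer dimensions, while a shape computed at runtime must be supplied as the layer's second input. A sketch of the two paths, with `shape_tensor` assumed to be a 1-D int32 ITensor whose length equals the rank of `input`:

    layer = ctx.net.add_resize(input)

    # Static output shape: a plain list of ints is accepted directly.
    layer.shape = [1, 3, 64, 64]

    # Dynamic output shape: dimensions known only at runtime cannot go
    # through the ints-only setter, so the shape is passed as input 1.
    layer.set_input(1, shape_tensor)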