From 4a4cec1a0cc2738fd939aece6b36b5e8f2aa1941 Mon Sep 17 00:00:00 2001
From: apbose
Date: Fri, 26 Sep 2025 18:56:00 -0700
Subject: [PATCH] addresses the case when shape of upsample tensor contains
 ITensor

---
 .../dynamo/conversion/impl/shape.py           | 37 +++++++++++++++++++
 .../dynamo/conversion/impl/upsample.py        | 17 ++++++---
 2 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
index 27af02e5bb..c487dfe598 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -123,3 +123,40 @@ def get_shape_with_dynamic_shape(
     select_layer = ctx.net.add_select(condition_val, input_shape, scale_res)
     set_layer_name(select_layer, target, f"{name}_select")
     return select_layer.get_output(0)
+
+
+def to_trt_shape_tensor(
+    ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor]
+) -> List[int] | TRTTensor:
+    """
+    Convert a mixed shape list (ints + ITensors) into a single ITensor.
+
+    Args:
+        ctx: ConversionContext
+        target: fx node target (used for naming).
+        name (str): base name for layer naming.
+        shape_list (list[int | ITensor]): list containing static ints and/or ITensors.
+
+    Returns:
+        ITensor if shape_list contains any ITensors, else the original list of ints.
+    """
+    # Fast path: a fully static shape (or empty list) needs no TRT layers at all.
+    if all(isinstance(s, int) for s in shape_list):
+        return shape_list
+
+    trt_tensors = []
+    for i, s in enumerate(shape_list):
+        if isinstance(s, int):
+            # Wrap each static dim in a 1-element INT32 constant so it can be
+            # concatenated with the dynamic-dim ITensors.
+            const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32))
+            set_layer_name(const, target, f"{name}_dim{i}_const")
+            trt_tensors.append(const.get_output(0))
+        else:
+            # Already an ITensor (dynamic dim).
+            trt_tensors.append(s)
+
+    # Concatenate everything into a single 1-D shape tensor.
+    concat_layer = ctx.net.add_concatenation(trt_tensors)
+    concat_layer.axis = 0
+    set_layer_name(concat_layer, target, f"{name}_shape_concat")
+    return concat_layer.get_output(0)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
index 4b47ca5dec..4cedb396d1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/upsample.py
@@ -9,7 +9,10 @@
     has_dynamic_shape,
     set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.conversion.impl.shape import (
+    get_shape_with_dynamic_shape,
+    to_trt_shape_tensor,
+)
 
 
 def upsample(
@@ -28,14 +31,22 @@ def upsample(
     if scale_factor is not None:
         layer.scales = [1.0, 1.0] + list(scale_factor)
     else:
-        shape = list(input.shape)[:2] + list(size)
+        shape = list(input.shape)[:2]
+        if size is not None:
+            shape += list(size)
         if has_dynamic_shape(shape):
             shape = get_shape_with_dynamic_shape(
                 ctx, target, source_ir, name, shape, input
             )
             layer.set_input(1, shape)
         else:
-            layer.shape = shape
+            shape = to_trt_shape_tensor(ctx, target, name, shape)
+            if isinstance(shape, list):
+                # Fully static shape: set it directly on the resize layer.
+                layer.shape = shape
+            else:
+                # Shape holds ITensors: feed it as the shape-tensor input
+                # (IResizeLayer.shape only accepts static Dims).
+                layer.set_input(1, shape)
 
     if mode == "nearest":
         layer.resize_mode = trt.InterpolationMode.NEAREST