diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 5b13ff080..fdceae903 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -139,6 +139,7 @@ def __init__( self.original_network_io.update( {io.name: io.type.tensor_type.elem_type for io in self.model.graph.output} ) + self.network_outputs = {o.name: o for o in self.model.graph.output} self.min_opset = min_opset self.max_ir_version = max_ir_version self.trt_plugins = trt_plugins @@ -1043,6 +1044,12 @@ def _add_cast( name=f"{tensor_name}_cast_to_{cast_to.str_short}", ) + if tensor_name in self.network_outputs: + value_info = helper.make_tensor_value_info(tensor_name, + cast_to.onnx_type, + [d.dim_value for d in self.network_outputs[tensor_name].type.tensor_type.shape.dim]) + self.model.graph.value_info.append(value_info) + if tensor_to_consumers is None: consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name) else: