From 12cd2813564990cff8282ef435a0d00424e2eaa8 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 2 Sep 2024 05:56:08 -0700 Subject: [PATCH 1/6] Add ONNX Exporter and Enable Engine build from ONNX --- __init__.py | 14 ++- onnx_nodes.py | 44 +++++++ onnx_utils/export.py | 294 +++++++++++++++++++++++++++++++++++++++++++ tensorrt_convert.py | 185 ++++++--------------------- tensorrt_loader.py | 24 ++-- 5 files changed, 400 insertions(+), 161 deletions(-) create mode 100644 onnx_nodes.py create mode 100644 onnx_utils/export.py diff --git a/__init__.py b/__init__.py index 468b3d5..5fc0e23 100644 --- a/__init__.py +++ b/__init__.py @@ -1,11 +1,13 @@ -from .tensorrt_convert import DYNAMIC_TRT_MODEL_CONVERSION -from .tensorrt_convert import STATIC_TRT_MODEL_CONVERSION -from .tensorrt_loader import TrTUnet -from .tensorrt_loader import TensorRTLoader +from .tensorrt_convert import NODE_CLASS_MAPPINGS as CONVERT_CLASS_MAP +from .tensorrt_convert import NODE_DISPLAY_NAME_MAPPINGS as CONVERT_NAME_MAP -NODE_CLASS_MAPPINGS = { "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION, "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION, "TensorRTLoader": TensorRTLoader } +from .tensorrt_loader import NODE_CLASS_MAPPINGS as LOADER_CLASS_MAP +from .tensorrt_loader import NODE_DISPLAY_NAME_MAPPINGS as LOADER_NAME_MAP +from .onnx_nodes import NODE_CLASS_MAPPING as ONNX_CLASS_MAP +from .onnx_nodes import NODE_DISPLAY_NAME_MAPPINGS as ONNX_NAME_MAP -NODE_DISPLAY_NAME_MAPPINGS = { "DYNAMIC_TRT_MODEL_CONVERSION": "DYNAMIC TRT_MODEL CONVERSION", "STATIC TRT_MODEL CONVERSION": STATIC_TRT_MODEL_CONVERSION, "TensorRTLoader": "TensorRT Loader" } +NODE_CLASS_MAPPINGS = CONVERT_CLASS_MAP | LOADER_CLASS_MAP | ONNX_CLASS_MAP +NODE_DISPLAY_NAME_MAPPINGS = CONVERT_NAME_MAP | LOADER_NAME_MAP | ONNX_NAME_MAP __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] \ No newline at end of file diff --git a/onnx_nodes.py b/onnx_nodes.py new file mode 100644 index 0000000..b983c69 --- /dev/null +++ b/onnx_nodes.py @@ -0,0 +1,44 @@ +import os +from .onnx_utils.export import export_onnx +import comfy + +class ONNX_EXPORT: + def __init__(self) -> None: + pass + + RETURN_TYPES = () + FUNCTION = "export" + OUTPUT_NODE = True + CATEGORY = "TensorRT" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "output_folder": ("STRING",) + }, + "optional": { + "filename": ("STRING", {"default": "model.onnx"}) + } + } + + def export(self, model, output_folder, filename): + comfy.model_management.unload_all_models() + comfy.model_management.load_models_gpu([model], force_patch_weights=True) + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + path = os.path.join(output_folder, filename) + export_onnx(model, path) + print(f"INFO: Exported Model to: {path}") + return () + + +NODE_CLASS_MAPPING = { + "ONNX_EXPORT": ONNX_EXPORT, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ONNX_EXPORT": "ONNX Export", +} \ No newline at end of file diff --git a/onnx_utils/export.py b/onnx_utils/export.py new file mode 100644 index 0000000..4994f17 --- /dev/null +++ b/onnx_utils/export.py @@ -0,0 +1,294 @@ +import onnx +import torch +from enum import Enum +import comfy +import os +from typing import List +from onnx.external_data_helper import _get_all_tensors, ExternalDataInfo + + +def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]: + """ + Gets the paths of the external data tensors in the model. + Note: make sure you load the model with load_external_data=False. 
+ """ + model_tensors = _get_all_tensors(model) + model_tensors_ext = [ + ExternalDataInfo(tensor).location + for tensor in model_tensors + if tensor.HasField("data_location") + and tensor.data_location == onnx.TensorProto.EXTERNAL + ] + return model_tensors_ext + + +class ModelType(Enum): + SD1x = "SD1.x" + SD2x768v = "SD2.x-768v" + SDXL_BASE = "SDXL-Base" + SDXL_REFINER = "SDXL-Refiner" + SVD = "SVD" + SD3 = "SD3" + AuraFlow = "AuraFlow" + FLUX_DEV = "FLUX-Dev" + FLUX_SCHNELL = "FLUX-Schnell" + UNKNOWN = "Unknown" + + def __eq__(self, value: object) -> bool: + return self.value == value + + @classmethod + def detect_version(cls, model): + if isinstance(model.model, comfy.model_base.SD3): + return cls.SD3 + elif isinstance(model.model, comfy.model_base.AuraFlow): + return cls.AuraFlow + elif isinstance(model.model, comfy.model_base.Flux): + if model.unet_config.guidance_embed: + return cls.FLUX_DEV + else: + return cls.FLUX_SCHNELL + + if model.model.model_config.unet_config.get("use_temporal_resblock", False): + return cls.SVD + + context_dim = model.model.model_config.unet_config.get("context_dim", None) + y_dim = model.model.adm_channels + + if context_dim == 768: + return cls.SD1x + elif context_dim == 1024: + return cls.SD2x768v + elif context_dim == 2048: + if y_dim == 2560: + return cls.SDXL_REFINER + elif y_dim == 2816: + return cls.SDXL_BASE + + return cls.UNKNOWN + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + @classmethod + def list_mo_support(cls): + return [ + cls.SD1x, + cls.SD2x768v, + cls.SDXL_BASE, + cls.SDXL_REFINER, + cls.SD3, + cls.FLUX_DEV, + ] + + +def get_io_names(y_dim: int = None, extra_input: dict = {}): + input_names = ["x", "timesteps", "context"] + output_names = ["h"] + dynamic_axes = { + "x": {0: "batch", 2: "height", 3: "width"}, + "timesteps": {0: "batch"}, + "context": {0: "batch", 1: "num_embeds"}, + } + + if y_dim: + input_names.append("y") + dynamic_axes["y"] = {0: "batch"} + + for k in extra_input: + input_names.append(k) + dynamic_axes[k] = {0: "batch"} + + return input_names, output_names, dynamic_axes + + +def get_shape( + model, + model_type: ModelType, + batch_size: int, + width: int, + height: int, + context_multiplier: int = 1, + num_video_frames: int = 12, + y_dim: int = None, + extra_input: dict = {}, # TODO batch_size*=2? 
+): + context_len = 77 + context_dim = model.model.model_config.unet_config.get("context_dim", None) + if model_type in (ModelType.AuraFlow, ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): + context_len = 256 + context_dim = 2048 + elif model_type == ModelType.SD3: + context_embedder_config = model.model.model_config.unet_config.get( + "context_embedder_config", None + ) + if context_embedder_config is not None: + context_dim = context_embedder_config.get("params", {}).get( + "in_features", None + ) + context_len = 154 # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77 + elif model_type == ModelType.SVD: + batch_size = batch_size * num_video_frames + context_len = 1 + + assert context_dim is not None + + input_channels = model.model.model_config.unet_config.get("in_channels", 4) + inputs_shapes = ( + (batch_size, input_channels, height // 8, width // 8), + (batch_size,), + (batch_size, context_len * context_multiplier, context_dim), + ) + if y_dim > 0: + inputs_shapes += ((batch_size, y_dim),) + + for k in extra_input: + inputs_shapes += ((batch_size,) + extra_input[k],) + + return inputs_shapes + + +def get_sample_input(input_shapes: tuple, dtype: torch.dtype, device: torch.device): + inputs = () + for shape in input_shapes: + inputs += ( + torch.zeros( + shape, + device=device, + dtype=dtype, + ), + ) + + return inputs + + +def get_dtype(model_type: ModelType): + if model_type in (ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): + return torch.bfloat16 + return torch.float16 + + +def get_backbone(model, model_type, input_names, num_video_frames): + unet = model.model.diffusion_model + transformer_options = model.model_options["transformer_options"].copy() + + if model_type == ModelType.SVD: + + class UNET(torch.nn.Module): + def forward(self, x, timesteps, context, y): + return self.unet( + x, + timesteps, + context, + y, + num_video_frames=self.num_video_frames, + transformer_options=self.transformer_options, + ) + + svd_unet = UNET() + svd_unet.num_video_frames = num_video_frames + svd_unet.unet = unet + svd_unet.transformer_options = transformer_options + unet = svd_unet + else: + + class UNET(torch.nn.Module): + def forward(self, x, timesteps, context, *args): + extras = input_names[3:] + extra_args = {} + for i in range(len(extras)): + extra_args[extras[i]] = args[i] + return self.unet( + x, + timesteps, + context, + transformer_options=self.transformer_options, + **extra_args, + ) + + _unet = UNET() + _unet.unet = unet + _unet.transformer_options = transformer_options + unet = _unet + + return unet + + +def get_extra_input(model, model_type): + y_dim = model.model.adm_channels + extra_input = {} + if model_type in (ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): + y_dim = model.model.model_config.unet_config.get("vec_in_dim", None) + extra_input = {"guidance": ()} + return y_dim, extra_input + + +def get_io_names_onnx(model: onnx.ModelProto): + input_names = [i.name for i in model.graph.input] + return input_names, None, None + + +def export_onnx( + model, + path, + batch_size: int = 1, + height: int = 512, + width: int = 512, + num_video_frames: int = 12, + context_multiplier: int = 1, +): + model_type = ModelType.detect_version(model) + if model_type == ModelType.UNKNOWN: + raise Exception("ERROR: model not supported.") + + y_dim, extra_input = get_extra_input(model, model_type) + input_names, output_names, dynamic_axes = get_io_names(y_dim, extra_input) + dtype = get_dtype(model_type) + device = 
comfy.model_management.get_torch_device() + input_shapes = get_shape( + model, + model_type, + batch_size, + width, + height, + context_multiplier, + num_video_frames, + y_dim, + extra_input, + ) + inputs = get_sample_input(input_shapes, dtype, device) + backbone = get_backbone(model, model_type, input_names, num_video_frames) + + torch.onnx.export( + backbone, + inputs, + path, + verbose=False, + input_names=input_names, + output_names=output_names, + opset_version=19, + dynamic_axes=dynamic_axes, + ) + + comfy.model_management.unload_all_models() + comfy.model_management.soft_empty_cache() + dir, name = os.path.split(path) + onnx_model = onnx.load(path, load_external_data=False) + tensors_paths = _get_onnx_external_data_tensors(onnx_model) + + if not tensors_paths: + return + + onnx_model = onnx.load(path, load_external_data=True) + for tensor in tensors_paths: + os.remove(os.path.join(dir, tensor)) + + onnx.save( + onnx_model, + path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=name + "_data", + size_threshold=1024, + ) diff --git a/tensorrt_convert.py b/tensorrt_convert.py index 5774aa3..bba5f90 100644 --- a/tensorrt_convert.py +++ b/tensorrt_convert.py @@ -1,5 +1,3 @@ -import torch -import sys import os import time import comfy.model_management @@ -7,6 +5,13 @@ import tensorrt as trt import folder_paths from tqdm import tqdm +from .onnx_utils.export import ( + get_io_names, + get_extra_input, + get_shape, + export_onnx, + ModelType + ) # TODO: # Make it more generic: less model specific code @@ -147,148 +152,26 @@ def _convert( context_max, num_video_frames, is_static: bool, + output_onnx: bool = False, ): - output_onnx = os.path.normpath( - os.path.join( - os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx" - ) - ) - - comfy.model_management.unload_all_models() - comfy.model_management.load_models_gpu([model], force_patch_weights=True) - unet = model.model.diffusion_model - - context_dim = model.model.model_config.unet_config.get("context_dim", None) - context_len = 77 - context_len_min = context_len - y_dim = model.model.adm_channels - extra_input = {} - dtype = torch.float16 - - if isinstance(model.model, comfy.model_base.SD3): #SD3 - context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None) - if context_embedder_config is not None: - context_dim = context_embedder_config.get("params", {}).get("in_features", None) - context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77 - elif isinstance(model.model, comfy.model_base.AuraFlow): - context_dim = 2048 - context_len_min = 256 - context_len = 256 - elif isinstance(model.model, comfy.model_base.Flux): - context_dim = model.model.model_config.unet_config.get("context_in_dim", None) - context_len_min = 256 - context_len = 256 - y_dim = model.model.model_config.unet_config.get("vec_in_dim", None) - extra_input = {"guidance": ()} - dtype = torch.bfloat16 - - if context_dim is not None: - input_names = ["x", "timesteps", "context"] - output_names = ["h"] - - dynamic_axes = { - "x": {0: "batch", 2: "height", 3: "width"}, - "timesteps": {0: "batch"}, - "context": {0: "batch", 1: "num_embeds"}, - } - - transformer_options = model.model_options['transformer_options'].copy() - if model.model.model_config.unet_config.get( - "use_temporal_resblock", False - ): # SVD - batch_size_min = num_video_frames * batch_size_min - batch_size_opt = num_video_frames * batch_size_opt - batch_size_max = 
num_video_frames * batch_size_max - - class UNET(torch.nn.Module): - def forward(self, x, timesteps, context, y): - return self.unet( - x, - timesteps, - context, - y, - num_video_frames=self.num_video_frames, - transformer_options=self.transformer_options, - ) - - svd_unet = UNET() - svd_unet.num_video_frames = num_video_frames - svd_unet.unet = unet - svd_unet.transformer_options = transformer_options - unet = svd_unet - context_len_min = context_len = 1 - else: - class UNET(torch.nn.Module): - def forward(self, x, timesteps, context, *args): - extras = input_names[3:] - extra_args = {} - for i in range(len(extras)): - extra_args[extras[i]] = args[i] - return self.unet(x, timesteps, context, transformer_options=self.transformer_options, **extra_args) - - _unet = UNET() - _unet.unet = unet - _unet.transformer_options = transformer_options - unet = _unet - - input_channels = model.model.model_config.unet_config.get("in_channels", 4) - - inputs_shapes_min = ( - (batch_size_min, input_channels, height_min // 8, width_min // 8), - (batch_size_min,), - (batch_size_min, context_len_min * context_min, context_dim), - ) - inputs_shapes_opt = ( - (batch_size_opt, input_channels, height_opt // 8, width_opt // 8), - (batch_size_opt,), - (batch_size_opt, context_len * context_opt, context_dim), - ) - inputs_shapes_max = ( - (batch_size_max, input_channels, height_max // 8, width_max // 8), - (batch_size_max,), - (batch_size_max, context_len * context_max, context_dim), - ) - - if y_dim > 0: - input_names.append("y") - dynamic_axes["y"] = {0: "batch"} - inputs_shapes_min += ((batch_size_min, y_dim),) - inputs_shapes_opt += ((batch_size_opt, y_dim),) - inputs_shapes_max += ((batch_size_max, y_dim),) - - for k in extra_input: - input_names.append(k) - dynamic_axes[k] = {0: "batch"} - inputs_shapes_min += ((batch_size_min,) + extra_input[k],) - inputs_shapes_opt += ((batch_size_opt,) + extra_input[k],) - inputs_shapes_max += ((batch_size_max,) + extra_input[k],) - - - inputs = () - for shape in inputs_shapes_opt: - inputs += ( - torch.zeros( - shape, - device=comfy.model_management.get_torch_device(), - dtype=dtype, - ), + if not output_onnx: + output_onnx = os.path.normpath( + os.path.join( + os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx" ) + ) + os.makedirs(os.path.dirname(output_onnx), exist_ok=True) - else: - print("ERROR: model not supported.") - return () + comfy.model_management.unload_all_models() + comfy.model_management.load_models_gpu([model], force_patch_weights=True) + export_onnx(model, output_onnx) - os.makedirs(os.path.dirname(output_onnx), exist_ok=True) - torch.onnx.export( - unet, - inputs, - output_onnx, - verbose=False, - input_names=input_names, - output_names=output_names, - opset_version=17, - dynamic_axes=dynamic_axes, - ) + model_type = ModelType.detect_version(model) + y_dim, extra_input = get_extra_input(model, model_type) + input_names, _, _ = get_io_names(y_dim, extra_input) + inputs_shapes_min = get_shape(model, model_type, batch_size_min, width_min, height_min, context_min, num_video_frames, y_dim, extra_input) + inputs_shapes_opt = get_shape(model, model_type, batch_size_opt, width_opt, height_opt, context_opt, num_video_frames, y_dim, extra_input) + inputs_shapes_max = get_shape(model, model_type, batch_size_max, width_max, height_max, context_max, num_video_frames, y_dim, extra_input) comfy.model_management.unload_all_models() comfy.model_management.soft_empty_cache() @@ -327,11 +210,7 @@ def forward(self, x, timesteps, context, *args): 
                input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
             )
 
-        if dtype == torch.float16:
-            config.set_flag(trt.BuilderFlag.FP16)
-        if dtype == torch.bfloat16:
-            config.set_flag(trt.BuilderFlag.BF16)
-
+        config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
         config.add_optimization_profile(profile)
 
         if is_static:
@@ -372,6 +251,8 @@ def forward(self, x, timesteps, context, *args):
         )
 
         serialized_engine = builder.build_serialized_network(network, config)
+        if serialized_engine is None:
+            raise Exception("Failed to build Engine")
 
         full_output_folder, filename, counter, subfolder, filename_prefix = (
             folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@@ -516,6 +397,9 @@ def INPUT_TYPES(s):
                     },
                 ),
             },
+            "optional": {
+                "onnx_model_path": ("STRING", {"default": "", "forceInput": True}),
+            }
         }
 
     def convert(
@@ -535,6 +419,7 @@ def convert(
         context_opt,
         context_max,
         num_video_frames,
+        onnx_model_path,
     ):
         return super()._convert(
             model,
@@ -553,6 +438,7 @@ def convert(
             context_max,
             num_video_frames,
             is_static=False,
+            output_onnx=onnx_model_path,
         )
 
 
@@ -612,6 +498,9 @@ def INPUT_TYPES(s):
                     },
                 ),
             },
+            "optional": {
+                "onnx_model_path": ("STRING", {"default": "", "forceInput": True}),
+            }
         }
 
     def convert(
@@ -623,6 +512,7 @@ def convert(
         width_opt,
         context_opt,
         num_video_frames,
+        onnx_model_path,
     ):
         return super()._convert(
             model,
@@ -641,6 +531,7 @@ def convert(
             context_opt,
             num_video_frames,
             is_static=True,
+            output_onnx=onnx_model_path,
         )
 
 
@@ -648,3 +539,7 @@ def convert(
     "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION,
     "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION,
 }
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "DYNAMIC_TRT_MODEL_CONVERSION": "DYNAMIC TRT_MODEL CONVERSION",
+    "STATIC_TRT_MODEL_CONVERSION": "STATIC TRT_MODEL CONVERSION",
+}
diff --git a/tensorrt_loader.py b/tensorrt_loader.py
index 5e2ccac..37e9a40 100644
--- a/tensorrt_loader.py
+++ b/tensorrt_loader.py
@@ -8,6 +8,7 @@
 import comfy.model_patcher
 import comfy.supported_models
 import folder_paths
+from .onnx_utils.export import ModelType
 
 if "tensorrt" in folder_paths.folder_names_and_paths:
     folder_paths.folder_names_and_paths["tensorrt"][0].append(
@@ -114,7 +115,7 @@ class TensorRTLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
-                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ),
+                             "model_type": (ModelType.list(), ),
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -125,41 +126,41 @@ def load_unet(self, unet_name, model_type):
         if not os.path.isfile(unet_path):
             raise FileNotFoundError(f"File {unet_path} does not exist")
         unet = TrTUnet(unet_path)
-        if model_type == "sdxl_base":
+        if model_type == ModelType.SDXL_BASE:
             conf = comfy.supported_models.SDXL({"adm_in_channels": 2816})
             conf.unet_config["disable_unet_model_creation"] = True
             model = comfy.model_base.SDXL(conf)
-        elif model_type == "sdxl_refiner":
+        elif model_type == ModelType.SDXL_REFINER:
             conf = comfy.supported_models.SDXLRefiner(
                 {"adm_in_channels": 2560})
             conf.unet_config["disable_unet_model_creation"] = True
             model = comfy.model_base.SDXLRefiner(conf)
-        elif model_type == "sd1.x":
+        elif model_type == ModelType.SD1x:
            conf = comfy.supported_models.SD15({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = comfy.model_base.BaseModel(conf)
-        elif model_type == "sd2.x-768v":
+        elif model_type == ModelType.SD2x768v:
             conf = comfy.supported_models.SD20({})
            conf.unet_config["disable_unet_model_creation"] = True
             model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION)
-        elif model_type == "svd":
+        elif model_type == ModelType.SVD:
             conf = comfy.supported_models.SVD_img2vid({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
-        elif model_type == "sd3":
+        elif model_type == ModelType.SD3:
             conf = comfy.supported_models.SD3({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
-        elif model_type == "auraflow":
+        elif model_type == ModelType.AuraFlow:
             conf = comfy.supported_models.AuraFlow({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
-        elif model_type == "flux_dev":
+        elif model_type == ModelType.FLUX_DEV:
             conf = comfy.supported_models.Flux({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
             unet.dtype = torch.bfloat16 #TODO: autodetect
-        elif model_type == "flux_schnell":
+        elif model_type == ModelType.FLUX_SCHNELL:
             conf = comfy.supported_models.FluxSchnell({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
@@ -174,3 +175,6 @@ def load_unet(self, unet_name, model_type):
 NODE_CLASS_MAPPINGS = {
     "TensorRTLoader": TensorRTLoader,
 }
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "TensorRTLoader": "TensorRT Loader"
+}
\ No newline at end of file

From 4365078210bc382bbe4ecdd87c8309e0775dc9c3 Mon Sep 17 00:00:00 2001
From: lspindler
Date: Thu, 19 Sep 2024 07:50:07 -0700
Subject: [PATCH 2/6] change default paths

---
 onnx_nodes.py        | 43 ++++++++++++++++++++++++++++++++++++-------
 onnx_utils/export.py | 26 +++++++++++++++++---------
 requirements.txt     |  6 +++++-
 3 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/onnx_nodes.py b/onnx_nodes.py
index b983c69..cc9ee50 100644
--- a/onnx_nodes.py
+++ b/onnx_nodes.py
@@ -1,6 +1,8 @@
 import os
 from .onnx_utils.export import export_onnx
-import comfy
+import comfy
+import folder_paths
+
 
 class ONNX_EXPORT:
     def __init__(self) -> None:
         pass
@@ -16,13 +18,14 @@ def INPUT_TYPES(s):
         return {
             "required": {
                 "model": ("MODEL",),
-                "output_folder": ("STRING",)
+                "output_folder": (
+                    "STRING",
+                    {"default": os.path.join(folder_paths.models_dir, "onnx")},
+                ),
             },
-            "optional": {
-                "filename": ("STRING", {"default": "model.onnx"})
-            }
+            "optional": {"filename": ("STRING", {"default": "model.onnx"})},
         }
-    
+
     def export(self, model, output_folder, filename):
         comfy.model_management.unload_all_models()
         comfy.model_management.load_models_gpu([model], force_patch_weights=True)
@@ -35,10 +38,36 @@ def export(self, model, output_folder, filename):
         return ()
 
 
+class ONNXModelSelector:
+    @classmethod
+    def INPUT_TYPES(s):
+        onnx_path = os.path.join(folder_paths.models_dir, "onnx")
+        if not os.path.exists(onnx_path):
+            os.makedirs(onnx_path)
+        onnx_models = [f for f in os.listdir(onnx_path) if f.endswith(".onnx")]
+        return {
+            "required": {
+                "model_name": (onnx_models,),
+            },
+        }
+
+    RETURN_TYPES = ("STRING", "STRING")
+    RETURN_NAMES = ("model_path", "model_name")
+    FUNCTION = "select_onnx_model"
+    CATEGORY = "TensorRT"
+
+    def select_onnx_model(self, model_name):
+        onnx_path = os.path.join(folder_paths.models_dir, "onnx")
+        model_path = os.path.join(onnx_path, model_name)
+        return (model_path, model_name)
+
+
 NODE_CLASS_MAPPING = {
     "ONNX_EXPORT": ONNX_EXPORT,
+    "ONNXModelSelector": ONNXModelSelector,
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
     "ONNX_EXPORT": "ONNX Export",
-}
\ No newline at end of file
+    "ONNXModelSelector": "Select ONNX Model",
+}
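A minimal sketch for sanity-checking an exported file with the onnx package outside of ComfyUI; the path is illustrative and assumes the default export folder used above (the "onnx" subfolder of the ComfyUI models directory):

    import onnx

    # Illustrative path; ONNX_EXPORT writes to <models_dir>/onnx/model.onnx by default.
    model_path = "models/onnx/model.onnx"

    # Passing the file path (rather than a ModelProto) lets the checker resolve
    # external weight data and handle models above the 2GB protobuf limit.
    onnx.checker.check_model(model_path)

    # Load only the graph structure to inspect exported inputs without pulling
    # the external weights into memory.
    model = onnx.load(model_path, load_external_data=False)
    for inp in model.graph.input:
        dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
        print(inp.name, dims)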
diff --git a/onnx_utils/export.py b/onnx_utils/export.py index 4994f17..49d4d40 100644 --- a/onnx_utils/export.py +++ b/onnx_utils/export.py @@ -5,6 +5,8 @@ import os from typing import List from onnx.external_data_helper import _get_all_tensors, ExternalDataInfo +import folder_paths +import time def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]: @@ -260,10 +262,19 @@ def export_onnx( inputs = get_sample_input(input_shapes, dtype, device) backbone = get_backbone(model, model_type, input_names, num_video_frames) + dir, name = os.path.split(path) + temp_path = os.path.join(folder_paths.get_temp_directory(), "{}".format(time.time())) + onnx_temp = os.path.normpath( + os.path.join(temp_path, name) + ) + + if not os.path.exists(temp_path): + os.makedirs(temp_path) + torch.onnx.export( backbone, inputs, - path, + onnx_temp, verbose=False, input_names=input_names, output_names=output_names, @@ -273,16 +284,13 @@ def export_onnx( comfy.model_management.unload_all_models() comfy.model_management.soft_empty_cache() - dir, name = os.path.split(path) - onnx_model = onnx.load(path, load_external_data=False) - tensors_paths = _get_onnx_external_data_tensors(onnx_model) - if not tensors_paths: - return + onnx_model = onnx.load(onnx_temp, load_external_data=True) + tensors_paths = _get_onnx_external_data_tensors(onnx_model) - onnx_model = onnx.load(path, load_external_data=True) - for tensor in tensors_paths: - os.remove(os.path.join(dir, tensor)) + if tensors_paths: + for tensor in tensors_paths: + os.remove(os.path.join(onnx_temp, tensor)) onnx.save( onnx_model, diff --git a/requirements.txt b/requirements.txt index cef004a..564aae7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,6 @@ -tensorrt>=10.0.1 +tensorrt>=10.4.0 --extra-index-url https://pypi.nvidia.com onnx!=1.16.2 +onnx-graphsurgeon +onnxmltools +onnxconverter-common +pulp From c3c58644bdcefa9946166179247825624260ccf2 Mon Sep 17 00:00:00 2001 From: lspindler Date: Tue, 22 Oct 2024 05:56:04 -0700 Subject: [PATCH 3/6] bug fixes --- onnx_nodes.py | 2 +- onnx_utils/export.py | 9 ++++++--- tensorrt_convert.py | 3 +-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/onnx_nodes.py b/onnx_nodes.py index cc9ee50..45beb89 100644 --- a/onnx_nodes.py +++ b/onnx_nodes.py @@ -28,7 +28,7 @@ def INPUT_TYPES(s): def export(self, model, output_folder, filename): comfy.model_management.unload_all_models() - comfy.model_management.load_models_gpu([model], force_patch_weights=True) + comfy.model_management.load_models_gpu([model], force_patch_weights=True, force_full_load=True) if not os.path.exists(output_folder): os.makedirs(output_folder) diff --git a/onnx_utils/export.py b/onnx_utils/export.py index 49d4d40..e240cfb 100644 --- a/onnx_utils/export.py +++ b/onnx_utils/export.py @@ -46,7 +46,7 @@ def detect_version(cls, model): elif isinstance(model.model, comfy.model_base.AuraFlow): return cls.AuraFlow elif isinstance(model.model, comfy.model_base.Flux): - if model.unet_config.guidance_embed: + if model.model.model_config.unet_config.get("guidance_embed", False): return cls.FLUX_DEV else: return cls.FLUX_SCHNELL @@ -118,9 +118,12 @@ def get_shape( ): context_len = 77 context_dim = model.model.model_config.unet_config.get("context_dim", None) - if model_type in (ModelType.AuraFlow, ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): + if model_type == ModelType.AuraFlow: context_len = 256 context_dim = 2048 + elif model_type in (ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): + context_len = 256 + context_dim = 
model.model.model_config.unet_config.get("context_in_dim", None) elif model_type == ModelType.SD3: context_embedder_config = model.model.model_config.unet_config.get( "context_embedder_config", None @@ -278,7 +281,7 @@ def export_onnx( verbose=False, input_names=input_names, output_names=output_names, - opset_version=19, + opset_version=17, dynamic_axes=dynamic_axes, ) diff --git a/tensorrt_convert.py b/tensorrt_convert.py index c75764a..95a2e6c 100644 --- a/tensorrt_convert.py +++ b/tensorrt_convert.py @@ -181,7 +181,7 @@ def _convert( builder = trt.Builder(logger) network = builder.create_network( - 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) ) parser = trt.OnnxParser(network, logger) success = parser.parse_from_file(output_onnx) @@ -210,7 +210,6 @@ def _convert( input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape) ) - config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) config.add_optimization_profile(profile) if is_static: From 37d2a9e0cfd6626e4fc91a5ecb6a5d793a645628 Mon Sep 17 00:00:00 2001 From: lspindler Date: Thu, 31 Oct 2024 02:06:09 -0700 Subject: [PATCH 4/6] refactor + controlnet support --- __init__.py | 13 +- models/__init__.py | 9 + models/auraflow.py | 23 ++ models/baseline.py | 102 ++++++ models/flux.py | 129 +++++++ models/sd3.py | 67 ++++ models/sd_unet.py | 420 +++++++++++++++++++++++ models/supported_models.py | 84 +++++ onnx_nodes.py | 4 +- onnx_utils/export.py | 284 ++++++--------- tensorrt_diffusion_model.py | 350 +++++++++++++++++++ tensorrt_loader.py | 180 ---------- tensorrt_convert.py => tensorrt_nodes.py | 231 ++++--------- 13 files changed, 1375 insertions(+), 521 deletions(-) create mode 100644 models/__init__.py create mode 100644 models/auraflow.py create mode 100644 models/baseline.py create mode 100644 models/flux.py create mode 100644 models/sd3.py create mode 100644 models/sd_unet.py create mode 100644 models/supported_models.py create mode 100644 tensorrt_diffusion_model.py delete mode 100644 tensorrt_loader.py rename tensorrt_convert.py => tensorrt_nodes.py (62%) diff --git a/__init__.py b/__init__.py index 5fc0e23..e4a40be 100644 --- a/__init__.py +++ b/__init__.py @@ -1,13 +1,10 @@ -from .tensorrt_convert import NODE_CLASS_MAPPINGS as CONVERT_CLASS_MAP -from .tensorrt_convert import NODE_DISPLAY_NAME_MAPPINGS as CONVERT_NAME_MAP - -from .tensorrt_loader import NODE_CLASS_MAPPINGS as LOADER_CLASS_MAP -from .tensorrt_loader import NODE_DISPLAY_NAME_MAPPINGS as LOADER_NAME_MAP +from .tensorrt_nodes import NODE_CLASS_MAPPINGS as TRT_CLASS_MAP +from .tensorrt_nodes import NODE_DISPLAY_NAME_MAPPINGS as TRT_NAME_MAP from .onnx_nodes import NODE_CLASS_MAPPING as ONNX_CLASS_MAP from .onnx_nodes import NODE_DISPLAY_NAME_MAPPINGS as ONNX_NAME_MAP -NODE_CLASS_MAPPINGS = CONVERT_CLASS_MAP | LOADER_CLASS_MAP | ONNX_CLASS_MAP -NODE_DISPLAY_NAME_MAPPINGS = CONVERT_NAME_MAP | LOADER_NAME_MAP | ONNX_NAME_MAP +NODE_CLASS_MAPPINGS = TRT_CLASS_MAP | ONNX_CLASS_MAP +NODE_DISPLAY_NAME_MAPPINGS = TRT_NAME_MAP | ONNX_NAME_MAP -__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] \ No newline at end of file +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..1e1e84e --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,9 @@ +from .supported_models import ( + supported_models, + unsupported_models, + detect_version_from_model, + get_helper_from_version, + 
get_helper_from_model, + get_model_from_version, +) +from .baseline import TRTModelUtil diff --git a/models/auraflow.py b/models/auraflow.py new file mode 100644 index 0000000..bb991e1 --- /dev/null +++ b/models/auraflow.py @@ -0,0 +1,23 @@ +from .baseline import TRTModelUtil + + +class AuraFlow_TRT(TRTModelUtil): + def __init__( + self, context_dim=2048, input_channels=4, context_len=256, **kwargs + ) -> None: + super().__init__( + context_dim=context_dim, + input_channels=input_channels, + context_len=context_len, + **kwargs, + ) + self.is_conditional = True + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.model_config.unet_config["cond_seq_dim"], + input_channels=model.model.diffusion_model.out_channels, + use_control=False, + **kwargs, + ) diff --git a/models/baseline.py b/models/baseline.py new file mode 100644 index 0000000..f72f30e --- /dev/null +++ b/models/baseline.py @@ -0,0 +1,102 @@ +import torch + + +class TRTModelUtil: + def __init__( + self, + context_dim: int, + input_channels: int, + context_len: int, + use_control: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.context_dim = context_dim + self.input_channels = input_channels + self.context_len = context_len + self.use_control = use_control + self.is_conditional = False + + self.input_config = { + "x": { + "batch": "{batch_size}", + "input_channels": self.input_channels, + "height": "{height}//8", + "width": "{width}//8", + }, + "timesteps": { + "batch": "{batch_size}", + }, + "context": { + "batch": "{batch_size}", + "context_len": "{context_len}", + "context_dim": self.context_dim, + }, + } + + self.output_config = { + "h": { + "batch": "{batch_size}", + "input_channels": self.input_channels, + "height": "{height}//8", + "width": "{width}//8", + } + } + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "context_len": self.context_dim, + "use_control": self.use_control, + } + + def get_input_names(self) -> list[str]: + return list(self.input_config.keys()) + + def get_output_names(self) -> list[str]: + return list(self.output_config.keys()) + + def get_dtype(self) -> torch.dtype: + return torch.float16 + + def get_input_shapes(self, **kwargs) -> dict: + inputs_shapes = {} + for io_name, io_config in self.input_config.items(): + _inp = self._eval_shape(io_config, **kwargs) + inputs_shapes[io_name] = _inp + + return inputs_shapes + + def get_input_shapes_by_key(self, key: str, **kwargs) -> tuple[int]: + return self._eval_shape(self.input_config[key], **kwargs) + + def get_dynamic_axes(self, config: dict = {}) -> dict: + dynamic_axes = {} + + if config == {}: + config = self.input_config | self.output_config + for k, v in config.items(): + dyn = {i: ax for i, (ax, s) in enumerate(v.items()) if isinstance(s, str)} + dynamic_axes[k] = dyn + + return dynamic_axes + + def _eval_shape(self, inp, **kwargs) -> tuple[int]: + if "context_len" not in kwargs: + kwargs["context_len"] = self.context_len + shape = [] + for _, v in inp.items(): + _s = v + if isinstance(v, str): + _s = int(eval(v.format(**kwargs))) + shape.append(_s) + return tuple(shape) + + def get_control(self, *args, **kwargs) -> dict: + raise NotImplementedError + + @classmethod + def from_model(cls, model, **kwargs): + raise NotImplementedError diff --git a/models/flux.py b/models/flux.py new file mode 100644 index 0000000..d161c31 --- /dev/null +++ b/models/flux.py @@ -0,0 +1,129 @@ +from .baseline import TRTModelUtil 
+import torch + + +class FLuxBase(TRTModelUtil): + def __init__( + self, + context_dim: int, + input_channels: int, + y_dim: int, + hidden_size: int, + double_blocks: int, + single_blocks: int, + *args, + **kwargs, + ) -> None: + super().__init__(context_dim, input_channels, 256, *args, **kwargs) + + self.hidden_size = hidden_size + self.y_dim = y_dim + self.single_blocks = single_blocks + self.double_blocks = double_blocks + + self.extra_input = { + "guidance": {"batch": "{batch_size}"}, + "y": {"batch": "{batch_size}", "y_dim": y_dim}, + } + + self.input_config.update(self.extra_input) + + if self.use_control: + self.control = self.get_control(double_blocks, single_blocks) + self.input_config.update(self.control) + + def to_dict(self): + return { + self.__name__: { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "y_dim": self.y_dim, + "hidden_size": self.hidden_size, + "double_blocks": self.double_blocks, + "single_blocks": self.single_blocks, + "use_control": self.use_control, + } + } + + def get_control(self, double_blocks: int, single_blocks: int): + control_input = {} + for i in range(double_blocks): + control_input[f"input_control_{i}"] = { + "batch": "{batch_size}", + "ids": "({height}*{width}//(8*2)**2)", + "hidden_size": self.hidden_size, + } + for i in range(single_blocks): + control_input[f"output_control_{i}"] = { + "batch": "{batch_size}", + "ids": "({height}*{width}//(8*2)**2)", + "hidden_size": self.hidden_size, + } + return control_input + + def get_dtype(self): + return torch.bfloat16 + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.model_config.unet_config["context_in_dim"], + input_channels=model.model.model_config.unet_config["in_channels"], + hidden_size=model.model.model_config.unet_config["hidden_size"], + y_dim=model.model.model_config.unet_config["vec_in_dim"], + double_blocks=model.model.model_config.unet_config["depth"], + single_blocks=model.model.model_config.unet_config["depth_single_blocks"], + **kwargs, + ) + + +class Flux_TRT(FLuxBase): + def __init__( + self, + context_dim=4096, + input_channels=16, + y_dim=768, + hidden_size=3072, + double_blocks=19, + single_blocks=28, + **kwargs, + ): + super().__init__( + context_dim=context_dim, + input_channels=input_channels, + y_dim=y_dim, + hidden_size=hidden_size, + double_blocks=double_blocks, + single_blocks=single_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model): + return super(Flux_TRT, cls).from_model(model, use_control=True) + + +class FluxSchnell_TRT(FLuxBase): + def __init__( + self, + context_dim=4096, + input_channels=16, + y_dim=768, + hidden_size=3072, + double_blocks=19, + single_blocks=28, + **kwargs, + ): + super().__init__( + context_dim=context_dim, + input_channels=input_channels, + y_dim=y_dim, + hidden_size=hidden_size, + double_blocks=double_blocks, + single_blocks=single_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model): + return super(FluxSchnell_TRT, cls).from_model(model, use_control=True) diff --git a/models/sd3.py b/models/sd3.py new file mode 100644 index 0000000..11c3111 --- /dev/null +++ b/models/sd3.py @@ -0,0 +1,67 @@ +from .baseline import TRTModelUtil +import torch + + +class SD3_TRT(TRTModelUtil): + def __init__( + self, + context_dim: int = 4096, + input_channels: int = 16, + y_dim: int = 2048, + hidden_size: int = 1536, + output_blocks: int = 24, + *args, + **kwargs, + ) -> None: + super().__init__(context_dim, input_channels, 77, *args, **kwargs) + + 
self.hidden_size = hidden_size + self.y_dim = y_dim + self.is_conditional = True + self.output_blocks = output_blocks # - 1 # self.joint_blocks + + self.extra_input = { + "y": {"batch": "{batch_size}", "y_dim": y_dim}, + } + + self.input_config.update(self.extra_input) + + if self.use_control: + self.control = self.get_control(output_blocks) + self.input_config.update(self.control) + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "hidden_size": self.hidden_size, + "y_dim": self.y_dim, + "output_blocks": self.output_blocks, + "use_control": self.use_control, + } + + def get_control(self, output_blocks: int): + control_input = {} + for i in range(output_blocks): + control_input[f"output_control_{i}"] = { + "batch": "{batch_size}", + "ids": "({height}*{width}//(8*2)**2)", + "hidden_size": self.hidden_size, + } + + return control_input + + def get_dtype(self): + return torch.float16 + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.diffusion_model.context_embedder.in_features, + input_channels=model.model.diffusion_model.in_channels, + hidden_size=model.model.diffusion_model.context_embedder.out_features, + y_dim=model.model.model_config.unet_config.get("adm_in_channels", 0), + output_blocks=model.model.diffusion_model.depth, + use_control=True, + **kwargs, + ) diff --git a/models/sd_unet.py b/models/sd_unet.py new file mode 100644 index 0000000..742158b --- /dev/null +++ b/models/sd_unet.py @@ -0,0 +1,420 @@ +from .baseline import TRTModelUtil +import torch + +from comfy.supported_models import ( + SD15, + SD20, + SD21UnclipL, + SD21UnclipH, + SDXLRefiner, + SDXL, + SSD1B, + Segmind_Vega, + KOALA_700M, + KOALA_1B, + SVD_img2vid, + SD15_instructpix2pix, + SDXL_instructpix2pix, +) + + +class UNetTRT(TRTModelUtil): + def __init__( + self, + context_dim: int, + input_channels: int, + y_dim: int, + hidden_size: int, + channel_mult: tuple[int], + num_res_blocks: tuple[int], + context_len: int = 77, + *args, + **kwargs, + ) -> None: + super().__init__(context_dim, input_channels, context_len, *args, **kwargs) + + self.hidden_size = hidden_size + self.y_dim = y_dim + self.is_conditional = True + + self.channel_mult = channel_mult + self.num_res_blocks = num_res_blocks + self.set_block_chans() + + if self.y_dim: + self.input_config.update({"y": {"batch": "{batch_size}", "y_dim": y_dim}}) + + if self.use_control: + self.control = self.get_control() + self.input_config.update(self.control) + + def set_block_chans(self): + ch = self.hidden_size + ds = 1 + + input_block_chans = [(self.hidden_size, ds)] + for level, mult in enumerate(self.channel_mult): + for nr in range(self.num_res_blocks[level]): + ch = mult * self.hidden_size + input_block_chans.append((ch, ds)) + if level != len(self.channel_mult) - 1: + out_ch = ch + ch = out_ch + ds *= 2 + input_block_chans.append((ch, ds)) + self.input_block_chans = input_block_chans + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.model_config.unet_config["context_dim"], + input_channels=model.model.diffusion_model.in_channels, + hidden_size=model.model.model_config.unet_config["model_channels"], + y_dim=model.model.model_config.unet_config.get("adm_in_channels", 0), + channel_mult=model.model.diffusion_model.channel_mult, + num_res_blocks=model.model.diffusion_model.num_res_blocks, + **kwargs, + ) + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + 
"hidden_size": self.hidden_size, + "y_dim": self.y_dim, + "channel_mult": self.channel_mult, + "num_res_blocks": self.num_res_blocks, + "use_control": self.use_control, + } + + def get_control(self): + control_input = {} + + for i, (ch, d) in enumerate(reversed(self.input_block_chans)): + control_input[f"input_control_{i}"] = { + "batch": "{batch_size}", + "chn": ch, + f"height{d}": "{height}//(8*" + str(d) + ")", + f"width{d}": "{width}//(8*" + str(d) + ")", + } + + for i, (ch, d) in enumerate(self.input_block_chans): + control_input[f"output_control_{i}"] = { + "batch": "{batch_size}", + "chn": ch, + f"height{d}": "{height}//(8*" + str(d) + ")", + f"width{d}": "{width}//(8*" + str(d) + ")", + } + + ch, d = self.input_block_chans[-1] + control_input["middle_control_0"] = { + "batch": "{batch_size}", + "chn": ch, + f"height{d}": "{height}//(8*" + str(d) + ")", + f"width{d}": "{width}//(8*" + str(d) + ")", + } + return control_input + # return {} + + def get_dtype(self): + return torch.float16 + + +class SD15_TRT(UNetTRT): + def __init__( + self, + context_dim=SD15.unet_config["context_dim"], + input_channels=4, + y_dim=0, + hidden_size=SD15.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model): + return super(SD15_TRT, cls).from_model(model, use_control=True) + + +class SD20_TRT(UNetTRT): + def __init__( + self, + context_dim=SD20.unet_config["context_dim"], + input_channels=4, + y_dim=0, + hidden_size=SD20.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model): + return super(SD20_TRT, cls).from_model(model, use_control=True) + + +class SD21UnclipL_TRT(UNetTRT): + def __init__( + self, + context_dim=SD21UnclipL.unet_config["context_dim"], + input_channels=4, + y_dim=SD21UnclipL.unet_config["adm_in_channels"], + hidden_size=SD21UnclipL.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SD21UnclipH_TRT(UNetTRT): + def __init__( + self, + context_dim=SD21UnclipH.unet_config["context_dim"], + input_channels=4, + y_dim=SD21UnclipH.unet_config["adm_in_channels"], + hidden_size=SD21UnclipH.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SDXLRefiner_TRT(UNetTRT): + def __init__( + self, + context_dim=SDXLRefiner.unet_config["context_dim"], + input_channels=4, + y_dim=SDXLRefiner.unet_config["adm_in_channels"], + hidden_size=SDXLRefiner.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SDXL_TRT(UNetTRT): + def __init__( + self, + context_dim=SDXL.unet_config["context_dim"], + input_channels=4, + y_dim=SDXL.unet_config["adm_in_channels"], + 
hidden_size=SDXL.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model): + return super(SDXL_TRT, cls).from_model(model, use_control=True) + + +class SSD1B_TRT(UNetTRT): + def __init__( + self, + context_dim=SSD1B.unet_config["context_dim"], + input_channels=4, + y_dim=SSD1B.unet_config["adm_in_channels"], + hidden_size=SSD1B.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class Segmind_Vega_TRT(UNetTRT): + def __init__( + self, + context_dim=Segmind_Vega.unet_config["context_dim"], + input_channels=4, + y_dim=Segmind_Vega.unet_config["adm_in_channels"], + hidden_size=Segmind_Vega.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class KOALA_700M_TRT(UNetTRT): + def __init__( + self, + context_dim=KOALA_700M.unet_config["context_dim"], + input_channels=4, + y_dim=KOALA_700M.unet_config["adm_in_channels"], + hidden_size=KOALA_700M.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class KOALA_1B_TRT(UNetTRT): + def __init__( + self, + context_dim=KOALA_1B.unet_config["context_dim"], + input_channels=4, + y_dim=KOALA_1B.unet_config["adm_in_channels"], + hidden_size=KOALA_1B.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SVD_img2vid_TRT(UNetTRT): + def __init__( + self, + context_dim=SVD_img2vid.unet_config["context_dim"], + input_channels=8, + y_dim=SVD_img2vid.unet_config["adm_in_channels"], + hidden_size=SVD_img2vid.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + self.input_config["context"]["context_len"] = self.context_len + + +class SD15_instructpix2pix_TRT(UNetTRT): + def __init__( + self, + context_dim=SD15_instructpix2pix.unet_config["context_dim"], + input_channels=8, + y_dim=0, + hidden_size=SD15_instructpix2pix.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SDXL_instructpix2pix_TRT(UNetTRT): + def __init__( + self, + context_dim=SDXL_instructpix2pix.unet_config["context_dim"], + input_channels=8, + y_dim=SDXL_instructpix2pix.unet_config["adm_in_channels"], + hidden_size=SDXL_instructpix2pix.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) diff --git 
a/models/supported_models.py b/models/supported_models.py new file mode 100644 index 0000000..1b09fa1 --- /dev/null +++ b/models/supported_models.py @@ -0,0 +1,84 @@ +from .baseline import TRTModelUtil +from .flux import Flux_TRT, FluxSchnell_TRT +from .auraflow import AuraFlow_TRT +from .sd3 import SD3_TRT +from .sd_unet import ( + SD15_TRT, + SD20_TRT, + SD21UnclipH_TRT, + SD21UnclipL_TRT, + SDXL_instructpix2pix_TRT, + SDXLRefiner_TRT, + SDXL_TRT, + SSD1B_TRT, + KOALA_700M_TRT, + KOALA_1B_TRT, + Segmind_Vega_TRT, + SVD_img2vid_TRT, +) +import comfy.supported_models + +supported_models = { + "SD15": SD15_TRT, + "SD20": SD20_TRT, + "SD21UnclipL": SD21UnclipL_TRT, + "SD21UnclipH": SD21UnclipH_TRT, + "SDXL_instructpix2pix": SDXL_instructpix2pix_TRT, + "SDXLRefiner": SDXLRefiner_TRT, + "SDXL": SDXL_TRT, + "SSD1B": SSD1B_TRT, + "KOALA_700M": KOALA_700M_TRT, + "KOALA_1B": KOALA_1B_TRT, + "Segmind_Vega": Segmind_Vega_TRT, + "SVD_img2vid": SVD_img2vid_TRT, + "SD3": SD3_TRT, + "AuraFlow": AuraFlow_TRT, + "Flux": Flux_TRT, + "FluxSchnell": FluxSchnell_TRT, +} + +unsupported_models = [ + "SV3D_u", + "SV3D_p", + "Stable_Zero123", + "SD_X4Upscaler", + "Stable_Cascade_C", + "Stable_Cascade_B", + "StableAudio", + "HunyuanDiT", + "HunyuanDiT1", +] + + +def detect_version_from_model(model): + return model.model.model_config.__class__.__name__ + + +def get_helper_from_version(model_version: str, config: dict = {}) -> TRTModelUtil: + model_helper = supported_models.get(model_version, None) + if model_helper is None: + raise NotImplementedError("{} is not supported.".format(model_version)) + return model_helper(**config) + + +def get_helper_from_model(model) -> TRTModelUtil: + model_version = detect_version_from_model(model) + helper_cls = supported_models.get(model_version, None) + if helper_cls is None: + raise NotImplementedError("{} is not supported.".format(model_version)) + return helper_cls.from_model(model) + + +def get_model_from_version(model_version: str, config: dict = {}): + conf = getattr(comfy.supported_models, model_version) + helper = get_helper_from_version(model_version, config) + conf.unet_config["disable_unet_model_creation"] = True + conf.unet_config["in_channels"] = helper.input_channels + conf = conf(conf.unet_config) + if model_version in ("SD20",): + model = comfy.model_base.BaseModel( + conf, model_type=comfy.model_base.ModelType.V_PREDICTION + ) + else: + model = conf.get_model({}) + return model, helper diff --git a/onnx_nodes.py b/onnx_nodes.py index 45beb89..1ae9969 100644 --- a/onnx_nodes.py +++ b/onnx_nodes.py @@ -28,7 +28,9 @@ def INPUT_TYPES(s): def export(self, model, output_folder, filename): comfy.model_management.unload_all_models() - comfy.model_management.load_models_gpu([model], force_patch_weights=True, force_full_load=True) + comfy.model_management.load_models_gpu( + [model], force_patch_weights=True, force_full_load=True + ) if not os.path.exists(output_folder): os.makedirs(output_folder) diff --git a/onnx_utils/export.py b/onnx_utils/export.py index e240cfb..0e4554a 100644 --- a/onnx_utils/export.py +++ b/onnx_utils/export.py @@ -1,12 +1,15 @@ import onnx import torch -from enum import Enum +import numpy as np +from onnx import numpy_helper +import json import comfy import os from typing import List from onnx.external_data_helper import _get_all_tensors, ExternalDataInfo import folder_paths import time +from ..models import detect_version_from_model, get_helper_from_model def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]: @@ -24,161 +27,25 
@@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]: return model_tensors_ext -class ModelType(Enum): - SD1x = "SD1.x" - SD2x768v = "SD2.x-768v" - SDXL_BASE = "SDXL-Base" - SDXL_REFINER = "SDXL-Refiner" - SVD = "SVD" - SD3 = "SD3" - AuraFlow = "AuraFlow" - FLUX_DEV = "FLUX-Dev" - FLUX_SCHNELL = "FLUX-Schnell" - UNKNOWN = "Unknown" - - def __eq__(self, value: object) -> bool: - return self.value == value - - @classmethod - def detect_version(cls, model): - if isinstance(model.model, comfy.model_base.SD3): - return cls.SD3 - elif isinstance(model.model, comfy.model_base.AuraFlow): - return cls.AuraFlow - elif isinstance(model.model, comfy.model_base.Flux): - if model.model.model_config.unet_config.get("guidance_embed", False): - return cls.FLUX_DEV - else: - return cls.FLUX_SCHNELL - - if model.model.model_config.unet_config.get("use_temporal_resblock", False): - return cls.SVD - - context_dim = model.model.model_config.unet_config.get("context_dim", None) - y_dim = model.model.adm_channels - - if context_dim == 768: - return cls.SD1x - elif context_dim == 1024: - return cls.SD2x768v - elif context_dim == 2048: - if y_dim == 2560: - return cls.SDXL_REFINER - elif y_dim == 2816: - return cls.SDXL_BASE - - return cls.UNKNOWN - - @classmethod - def list(cls): - return list(map(lambda c: c.value, cls)) - - @classmethod - def list_mo_support(cls): - return [ - cls.SD1x, - cls.SD2x768v, - cls.SDXL_BASE, - cls.SDXL_REFINER, - cls.SD3, - cls.FLUX_DEV, - ] - - -def get_io_names(y_dim: int = None, extra_input: dict = {}): - input_names = ["x", "timesteps", "context"] - output_names = ["h"] - dynamic_axes = { - "x": {0: "batch", 2: "height", 3: "width"}, - "timesteps": {0: "batch"}, - "context": {0: "batch", 1: "num_embeds"}, - } - - if y_dim: - input_names.append("y") - dynamic_axes["y"] = {0: "batch"} - - for k in extra_input: - input_names.append(k) - dynamic_axes[k] = {0: "batch"} - - return input_names, output_names, dynamic_axes - - -def get_shape( - model, - model_type: ModelType, - batch_size: int, - width: int, - height: int, - context_multiplier: int = 1, - num_video_frames: int = 12, - y_dim: int = None, - extra_input: dict = {}, # TODO batch_size*=2? 
-): - context_len = 77 - context_dim = model.model.model_config.unet_config.get("context_dim", None) - if model_type == ModelType.AuraFlow: - context_len = 256 - context_dim = 2048 - elif model_type in (ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): - context_len = 256 - context_dim = model.model.model_config.unet_config.get("context_in_dim", None) - elif model_type == ModelType.SD3: - context_embedder_config = model.model.model_config.unet_config.get( - "context_embedder_config", None - ) - if context_embedder_config is not None: - context_dim = context_embedder_config.get("params", {}).get( - "in_features", None - ) - context_len = 154 # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77 - elif model_type == ModelType.SVD: - batch_size = batch_size * num_video_frames - context_len = 1 - - assert context_dim is not None - - input_channels = model.model.model_config.unet_config.get("in_channels", 4) - inputs_shapes = ( - (batch_size, input_channels, height // 8, width // 8), - (batch_size,), - (batch_size, context_len * context_multiplier, context_dim), - ) - if y_dim > 0: - inputs_shapes += ((batch_size, y_dim),) - - for k in extra_input: - inputs_shapes += ((batch_size,) + extra_input[k],) - - return inputs_shapes - - -def get_sample_input(input_shapes: tuple, dtype: torch.dtype, device: torch.device): - inputs = () - for shape in input_shapes: - inputs += ( +def get_sample_input(input_shapes: dict, dtype: torch.dtype, device: torch.device): + inputs = [] + for k, shape in input_shapes.items(): + inputs.append( torch.zeros( shape, device=device, dtype=dtype, - ), + ) ) - return inputs - - -def get_dtype(model_type: ModelType): - if model_type in (ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): - return torch.bfloat16 - return torch.float16 + return tuple(inputs) -def get_backbone(model, model_type, input_names, num_video_frames): +def get_backbone(model, model_version, input_names, num_video_frames, use_control): unet = model.model.diffusion_model transformer_options = model.model_options["transformer_options"].copy() - if model_type == ModelType.SVD: + if model_version == "SVD_img2vid": class UNET(torch.nn.Module): def forward(self, x, timesteps, context, y): @@ -201,9 +68,20 @@ def forward(self, x, timesteps, context, y): class UNET(torch.nn.Module): def forward(self, x, timesteps, context, *args): extras = input_names[3:] + control = {"input": [], "output": [], "middle": []} extra_args = {} for i in range(len(extras)): - extra_args[extras[i]] = args[i] + if "control" in extras[i]: + if "input" in extras[i]: + control["input"].append(args[i]) + elif "output" in extras[i]: + control["output"].append(args[i]) + elif "middle" in extras[i]: + control["middle"].append(args[i]) + else: + extra_args[extras[i]] = args[i] + if use_control: + extra_args["control"] = control return self.unet( x, timesteps, @@ -220,18 +98,67 @@ def forward(self, x, timesteps, context, *args): return unet -def get_extra_input(model, model_type): - y_dim = model.model.adm_channels - extra_input = {} - if model_type in (ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL): - y_dim = model.model.model_config.unet_config.get("vec_in_dim", None) - extra_input = {"guidance": ()} - return y_dim, extra_input +# Helper utility for weights map +def export_weights_map(state_dict, onnx_opt_path: str, weights_map_path: str): + onnx_opt_dir = onnx_opt_path + onnx_opt_model = onnx.load(onnx_opt_path) + + # Create initializer data hashes + def init_hash_map(onnx_opt_model): + 
initializer_hash_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = ( + initializer_hash, + initializer_data.shape, + ) + return initializer_hash_mapping + + initializer_hash_mapping = init_hash_map(onnx_opt_model) + + weights_name_mapping = {} + weights_shape_mapping = {} + # set to keep track of initializers already added to the name_mapping dict + initializers_mapped = set() + for wt_name, wt in state_dict.items(): + # get weight hash + wt = wt.cpu().detach().numpy().astype(np.float16) + wt_hash = hash(wt.data.tobytes()) + wt_t_hash = hash(np.transpose(wt).data.tobytes()) + + for initializer_name, ( + initializer_hash, + initializer_shape, + ) in initializer_hash_mapping.items(): + # Due to constant folding, some weights are transposed during export + # To account for the transpose op, we compare the initializer hash to the + # hash for the weight and its transpose + if wt_hash == initializer_hash or wt_t_hash == initializer_hash: + # The assert below ensures there is a 1:1 mapping between + # PyTorch and ONNX weight names. It can be removed in cases where 1:many + # mapping is found and name_mapping[wt_name] = list() + assert initializer_name not in initializers_mapped + weights_name_mapping[wt_name] = initializer_name + initializers_mapped.add(initializer_name) + is_transpose = False if wt_hash == initializer_hash else True + weights_shape_mapping[wt_name] = ( + initializer_shape, + is_transpose, + ) + # Sanity check: Were any weights not matched + if wt_name not in weights_name_mapping: + print(f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer") + print( + f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers" + ) -def get_io_names_onnx(model: onnx.ModelProto): - input_names = [i.name for i in model.graph.input] - return input_names, None, None + assert weights_name_mapping.keys() == weights_shape_mapping.keys() + with open(weights_map_path, "w") as fp: + json.dump([weights_name_mapping, weights_shape_mapping], fp) def export_onnx( @@ -240,36 +167,39 @@ def export_onnx( batch_size: int = 1, height: int = 512, width: int = 512, - num_video_frames: int = 12, + num_video_frames: int = 14, context_multiplier: int = 1, + use_control: bool = True, ): - model_type = ModelType.detect_version(model) - if model_type == ModelType.UNKNOWN: - raise Exception("ERROR: model not supported.") + model_version = detect_version_from_model(model) + model_helper = get_helper_from_model(model) - y_dim, extra_input = get_extra_input(model, model_type) - input_names, output_names, dynamic_axes = get_io_names(y_dim, extra_input) - dtype = get_dtype(model_type) + dtype = model_helper.get_dtype() device = comfy.model_management.get_torch_device() - input_shapes = get_shape( - model, - model_type, - batch_size, - width, - height, - context_multiplier, - num_video_frames, - y_dim, - extra_input, + input_names = model_helper.get_input_names() + output_names = model_helper.get_output_names() + dynamic_axes = model_helper.get_dynamic_axes() + + if model_version == "SVD_img2vid": + batch_size *= num_video_frames + if model_helper.is_conditional: + batch_size *= 2 + input_shapes = model_helper.get_input_shapes( + batch_size=batch_size, + height=height, + width=width, + context_multiplier=context_multiplier, ) inputs = 
get_sample_input(input_shapes, dtype, device) - backbone = get_backbone(model, model_type, input_names, num_video_frames) + backbone = get_backbone( + model, model_version, input_names, num_video_frames, use_control + ) dir, name = os.path.split(path) - temp_path = os.path.join(folder_paths.get_temp_directory(), "{}".format(time.time())) - onnx_temp = os.path.normpath( - os.path.join(temp_path, name) + temp_path = os.path.join( + folder_paths.get_temp_directory(), "{}".format(time.time()) ) + onnx_temp = os.path.normpath(os.path.join(temp_path, name)) if not os.path.exists(temp_path): os.makedirs(temp_path) diff --git a/tensorrt_diffusion_model.py b/tensorrt_diffusion_model.py new file mode 100644 index 0000000..684fb8c --- /dev/null +++ b/tensorrt_diffusion_model.py @@ -0,0 +1,350 @@ +import torch +import tensorrt as trt +import os +from typing import Optional +from tqdm import tqdm +from math import prod +import comfy.model_management +from .models import get_model_from_version, TRTModelUtil + +trt.init_libnvinfer_plugins(None, "") +logger = trt.Logger(trt.Logger.INFO) +runtime = trt.Runtime(logger) + + +def trt_datatype_to_torch(datatype): + if datatype == trt.float16: + return torch.float16 + elif datatype == trt.float32: + return torch.float32 + elif datatype == trt.int32: + return torch.int32 + elif datatype == trt.bfloat16: + return torch.bfloat16 + + +class TRTModel(torch.nn.Module): + def __init__(self, model_helper: TRTModelUtil, *args, **kwargs) -> None: + super(TRTModel, self).__init__() + + self.context = None + self.engine = None + + self.model = model_helper + self.dtype = self.model.get_dtype() + self.device = comfy.model_management.get_torch_device() + + self.input_names = self.model.get_input_names() + self.output_names = self.model.get_output_names() + + self.current_shape: tuple[int] = (0,) + self.output_shapes: dict[str, tuple[int]] = {} + self.curr_split_batch: int = 0 + + self.zero_pool = None + self.extra_inputs: dict[str, torch.Tensor] = {} + + # Sets up the builder to use the timing cache file, and creates it if it does not already exist + def _setup_timing_cache(self, config: trt.IBuilderConfig, timing_cache_path: str): + buffer = b"" + if os.path.exists(timing_cache_path): + with open(timing_cache_path, mode="rb") as timing_cache_file: + buffer = timing_cache_file.read() + print("Read {} bytes from timing cache.".format(len(buffer))) + else: + print("No timing cache found; Initializing a new one.") + timing_cache: trt.ITimingCache = config.create_timing_cache(buffer) + config.set_timing_cache(timing_cache, ignore_mismatch=True) + + # Saves the config's timing cache to file + def _save_timing_cache(self, config: trt.IBuilderConfig, timing_cache_path: str): + timing_cache: trt.ITimingCache = config.get_timing_cache() + with open(timing_cache_path, "wb") as timing_cache_file: + timing_cache_file.write(memoryview(timing_cache.serialize())) + + def _create_profile(self, builder, min_config, opt_config, max_config): + profile = builder.create_optimization_profile() + + min_config = opt_config if min_config is None else min_config + max_config = opt_config if min_config is None else max_config + + min_shapes = self.model.get_input_shapes(**min_config) + opt_shapes = self.model.get_input_shapes(**opt_config) + max_shapes = self.model.get_input_shapes(**max_config) + for input_name in opt_shapes.keys(): + profile.set_shape( + input_name, + min_shapes[input_name], + opt_shapes[input_name], + max_shapes[input_name], + ) + + return profile + + def build( + self, + 
onnx_path: str, + engine_path: str, + timing_cache_path: str, + opt_config: dict, + min_config: Optional[dict] = None, + max_config: Optional[dict] = None, + ) -> bool: + comfy.model_management.unload_all_models() + comfy.model_management.soft_empty_cache() + + # TRT conversion starts here + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + + network = builder.create_network( + 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + ) + parser = trt.OnnxParser(network, logger) + success = parser.parse_from_file(onnx_path) + for idx in range(parser.num_errors): + print(parser.get_error(idx)) + + if not success: + print("ONNX load ERROR") + return False + + config = builder.create_builder_config() + self._setup_timing_cache(config, timing_cache_path) + config.progress_monitor = TQDMProgressMonitor() + profile = self._create_profile(builder, min_config, opt_config, max_config) + config.add_optimization_profile(profile) + config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) # STRIP_PLAN + + self.engine = builder.build_serialized_network(network, config) + if self.engine is None: + raise Exception("Failed to build Engine") + + model = { + "engine": torch.ByteTensor( + bytearray(self.engine) + ), # TODO this isn't very efficient + "config": self.model.to_dict(), + } + torch.save(model, engine_path) + + return True + + @torch.cuda.nvtx.range("set_bindings_shape") + def set_bindings_shape(self, inputs): + for k in inputs: + shape = inputs[k].shape + shape = [shape[0] // self.curr_split_batch] + list(shape[1:]) + self.context.set_input_shape(k, shape) + + @torch.cuda.nvtx.range("setup_tensors") + def setup_tensors(self, model_inputs): + raise NotImplementedError + + @classmethod + def load_trt_model(cls, engine_path, model_type): + try: + engine = torch.load(engine_path) + config = engine["config"] + model, helper = get_model_from_version(model_type, config) + unet = cls(helper) + unet.engine = runtime.deserialize_cuda_engine( + engine["engine"].numpy().tobytes() + ) + except: + model, helper = get_model_from_version(model_type, {}) + unet = cls(helper) + with open(engine_path, "rb") as f: + unet.engine = runtime.deserialize_cuda_engine(f.read()) + + if unet.engine is None: + raise Exception("Failed to load Engine") + + unet.context = unet.engine.create_execution_context() + model.diffusion_model = unet + model.memory_required = ( + lambda *args, **kwargs: 0 + ) # always pass inputs batched up as much as possible + return model + + @torch.cuda.nvtx.range("__call__") + def __call__(self): + raise NotImplementedError + + +class TRTDiffusionBackbone(TRTModel): + + @torch.cuda.nvtx.range("setup_tensors") + def setup_tensors(self, model_inputs): + self.current_shape = model_inputs["x"].shape + self.extra_inputs = {} + + dims = self.engine.get_tensor_profile_shape(self.engine.get_tensor_name(0), 0) + batch_size, _, height, width = self.current_shape + height *= 8 + width *= 8 + _, context_len, _ = model_inputs["context"].shape + min_batch = dims[0][0] + opt_batch = dims[1][0] + max_batch = dims[2][0] + # Split batch if our batch is bigger than the max batch size the trt engine supports + for i in range(max_batch, min_batch - 1, -1): + if batch_size % i == 0: + self.curr_split_batch = batch_size // i + break + + self.set_bindings_shape(model_inputs) + + # Inputs missing, use zero + max_memory = 0 + for name in self.input_names: + if name in model_inputs: + continue + shape = self.model.get_input_shapes_by_key( + name, + batch_size=batch_size, + height=height, + width=width, + 
context_len=context_len, + ) + shape = (shape[0] // self.curr_split_batch, *shape[1:]) + self.context.set_input_shape(name, shape) + self.extra_inputs[name] = 0 + max_memory = max(prod(shape), max_memory) + self.zero_pool = torch.zeros( + max_memory, device=self.device, dtype=self.dtype + ).contiguous() + + self.output_shapes = {} + for name in self.output_names: + shape = list(self.engine.get_tensor_shape("h")) + for idx in range(len(shape)): + if shape[idx] == -1: + shape[idx] = model_inputs["x"].shape[idx] + if idx == 0: + shape[idx] = batch_size + self.output_shapes[name] = shape + + @torch.cuda.nvtx.range("__call__") + def __call__( + self, + x, + timesteps, + context, + y=None, + control=None, + transformer_options=None, + **kwargs, + ): + model_inputs = {"x": x, "timesteps": timesteps, "context": context} + + if y is not None: + model_inputs["y"] = y + + if self.model.use_control and control is not None: + for control_layer, control_tensors in control.items(): + for i, tensor in enumerate(control_tensors): + model_inputs[f"{control_layer}_control_{i}"] = tensor + + for k, v in kwargs.items(): + # TODO actually needed? model_inputs[k] = v + pass + + if self.current_shape != x.shape: + self.setup_tensors(model_inputs) + + model_inputs["h"] = torch.empty( + self.output_shapes["h"], + device=self.device, + dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype("h")), + ).contiguous() + + for k in model_inputs: + trt_dtype = trt_datatype_to_torch(self.engine.get_tensor_dtype(k)) + if model_inputs[k].dtype != trt_dtype: + model_inputs[k] = model_inputs[k].to(trt_dtype) + + torch.cuda.nvtx.range_push("infer") + stream = torch.cuda.default_stream(x.device) + for i in range(self.curr_split_batch): + for k, v in model_inputs.items(): + self.context.set_tensor_address( + k, v[(v.shape[0] // self.curr_split_batch) * i :].data_ptr() + ) + for k in self.extra_inputs.keys(): + self.context.set_tensor_address(k, self.zero_pool.data_ptr()) + self.context.execute_async_v3(stream_handle=stream.cuda_stream) + torch.cuda.nvtx.range_pop() + + return model_inputs["h"] + + +class TQDMProgressMonitor(trt.IProgressMonitor): + def __init__(self): + trt.IProgressMonitor.__init__(self) + self._active_phases = {} + self._step_result = True + self.max_indent = 5 + + def phase_start(self, phase_name, parent_phase, num_steps): + leave = False + try: + if parent_phase is not None: + nbIndents = ( + self._active_phases.get(parent_phase, {}).get( + "nbIndents", self.max_indent + ) + + 1 + ) + if nbIndents >= self.max_indent: + return + else: + nbIndents = 0 + leave = True + self._active_phases[phase_name] = { + "tq": tqdm( + total=num_steps, desc=phase_name, leave=leave, position=nbIndents + ), + "nbIndents": nbIndents, + "parent_phase": parent_phase, + } + except KeyboardInterrupt: + # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete. 
+ _step_result = False + + def phase_finish(self, phase_name): + try: + if phase_name in self._active_phases.keys(): + self._active_phases[phase_name]["tq"].update( + self._active_phases[phase_name]["tq"].total + - self._active_phases[phase_name]["tq"].n + ) + + parent_phase = self._active_phases[phase_name].get("parent_phase", None) + while parent_phase is not None: + self._active_phases[parent_phase]["tq"].refresh() + parent_phase = self._active_phases[parent_phase].get( + "parent_phase", None + ) + if ( + self._active_phases[phase_name]["parent_phase"] + in self._active_phases.keys() + ): + self._active_phases[ + self._active_phases[phase_name]["parent_phase"] + ]["tq"].refresh() + del self._active_phases[phase_name] + pass + except KeyboardInterrupt: + _step_result = False + + def step_complete(self, phase_name, step): + try: + if phase_name in self._active_phases.keys(): + self._active_phases[phase_name]["tq"].update( + step - self._active_phases[phase_name]["tq"].n + ) + return self._step_result + except KeyboardInterrupt: + # There is no need to propagate this exception to TensorRT. We can simply cancel the build. + return False diff --git a/tensorrt_loader.py b/tensorrt_loader.py deleted file mode 100644 index 37e9a40..0000000 --- a/tensorrt_loader.py +++ /dev/null @@ -1,180 +0,0 @@ -#Put this in the custom_nodes folder, put your tensorrt engine files in ComfyUI/models/tensorrt/ (you will have to create the directory) - -import torch -import os - -import comfy.model_base -import comfy.model_management -import comfy.model_patcher -import comfy.supported_models -import folder_paths -from .onnx_utils.export import ModelType - -if "tensorrt" in folder_paths.folder_names_and_paths: - folder_paths.folder_names_and_paths["tensorrt"][0].append( - os.path.join(folder_paths.models_dir, "tensorrt")) - folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") -else: - folder_paths.folder_names_and_paths["tensorrt"] = ( - [os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"}) - -import tensorrt as trt - -trt.init_libnvinfer_plugins(None, "") - -logger = trt.Logger(trt.Logger.INFO) -runtime = trt.Runtime(logger) - -# Is there a function that already exists for this? 
-def trt_datatype_to_torch(datatype): - if datatype == trt.float16: - return torch.float16 - elif datatype == trt.float32: - return torch.float32 - elif datatype == trt.int32: - return torch.int32 - elif datatype == trt.bfloat16: - return torch.bfloat16 - -class TrTUnet: - def __init__(self, engine_path): - with open(engine_path, "rb") as f: - self.engine = runtime.deserialize_cuda_engine(f.read()) - self.context = self.engine.create_execution_context() - self.dtype = torch.float16 - - def set_bindings_shape(self, inputs, split_batch): - for k in inputs: - shape = inputs[k].shape - shape = [shape[0] // split_batch] + list(shape[1:]) - self.context.set_input_shape(k, shape) - - def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs): - model_inputs = {"x": x, "timesteps": timesteps, "context": context} - - if y is not None: - model_inputs["y"] = y - - for i in range(len(model_inputs), self.engine.num_io_tensors - 1): - name = self.engine.get_tensor_name(i) - model_inputs[name] = kwargs[name] - - batch_size = x.shape[0] - dims = self.engine.get_tensor_profile_shape(self.engine.get_tensor_name(0), 0) - min_batch = dims[0][0] - opt_batch = dims[1][0] - max_batch = dims[2][0] - - #Split batch if our batch is bigger than the max batch size the trt engine supports - for i in range(max_batch, min_batch - 1, -1): - if batch_size % i == 0: - curr_split_batch = batch_size // i - break - - self.set_bindings_shape(model_inputs, curr_split_batch) - - model_inputs_converted = {} - for k in model_inputs: - data_type = self.engine.get_tensor_dtype(k) - model_inputs_converted[k] = model_inputs[k].to(dtype=trt_datatype_to_torch(data_type)) - - output_binding_name = self.engine.get_tensor_name(len(model_inputs)) - out_shape = self.engine.get_tensor_shape(output_binding_name) - out_shape = list(out_shape) - - #for dynamic profile case where the dynamic params are -1 - for idx in range(len(out_shape)): - if out_shape[idx] == -1: - out_shape[idx] = x.shape[idx] - else: - if idx == 0: - out_shape[idx] *= curr_split_batch - - out = torch.empty(out_shape, - device=x.device, - dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype(output_binding_name))) - model_inputs_converted[output_binding_name] = out - - stream = torch.cuda.default_stream(x.device) - for i in range(curr_split_batch): - for k in model_inputs_converted: - x = model_inputs_converted[k] - self.context.set_tensor_address(k, x[(x.shape[0] // curr_split_batch) * i:].data_ptr()) - self.context.execute_async_v3(stream_handle=stream.cuda_stream) - # stream.synchronize() #don't need to sync stream since it's the default torch one - return out - - def load_state_dict(self, sd, strict=False): - pass - - def state_dict(self): - return {} - - -class TensorRTLoader: - @classmethod - def INPUT_TYPES(s): - return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ), - "model_type": (ModelType.list(), ), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "load_unet" - CATEGORY = "TensorRT" - - def load_unet(self, unet_name, model_type): - unet_path = folder_paths.get_full_path("tensorrt", unet_name) - if not os.path.isfile(unet_path): - raise FileNotFoundError(f"File {unet_path} does not exist") - unet = TrTUnet(unet_path) - if model_type == ModelType.SDXL_BASE: - conf = comfy.supported_models.SDXL({"adm_in_channels": 2816}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.SDXL(conf) - elif model_type == ModelType.SDXL_REFINER: - conf = 
comfy.supported_models.SDXLRefiner( - {"adm_in_channels": 2560}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.SDXLRefiner(conf) - elif model_type == ModelType.SD1x: - conf = comfy.supported_models.SD15({}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.BaseModel(conf) - elif model_type == ModelType.SD2x768v: - conf = comfy.supported_models.SD20({}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION) - elif model_type == ModelType.SVD: - conf = comfy.supported_models.SVD_img2vid({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - elif model_type == ModelType.SD3: - conf = comfy.supported_models.SD3({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - elif model_type == ModelType.AuraFlow: - conf = comfy.supported_models.AuraFlow({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - elif model_type == ModelType.FLUX_DEV: - conf = comfy.supported_models.Flux({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - unet.dtype = torch.bfloat16 #TODO: autodetect - elif model_type == ModelType.FLUX_SCHNELL: - conf = comfy.supported_models.FluxSchnell({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - unet.dtype = torch.bfloat16 #TODO: autodetect - model.diffusion_model = unet - model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting - - return (comfy.model_patcher.ModelPatcher(model, - load_device=comfy.model_management.get_torch_device(), - offload_device=comfy.model_management.unet_offload_device()),) - -NODE_CLASS_MAPPINGS = { - "TensorRTLoader": TensorRTLoader, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "TensorRTLoader": "TensorRT Loader" -} \ No newline at end of file diff --git a/tensorrt_convert.py b/tensorrt_nodes.py similarity index 62% rename from tensorrt_convert.py rename to tensorrt_nodes.py index 95a2e6c..8b88d5e 100644 --- a/tensorrt_convert.py +++ b/tensorrt_nodes.py @@ -1,20 +1,15 @@ +# Put this in the custom_nodes folder, put your tensorrt engine files in ComfyUI/models/tensorrt/ (you will have to create the directory) import os import time -import comfy.model_management +from typing import Optional -import tensorrt as trt +import comfy.model_management +import comfy.model_patcher import folder_paths -from tqdm import tqdm -from .onnx_utils.export import ( - get_io_names, - get_extra_input, - get_shape, - export_onnx, - ModelType - ) -# TODO: -# Make it more generic: less model specific code +from .models import supported_models, detect_version_from_model, get_helper_from_model +from .onnx_utils.export import export_onnx +from .tensorrt_diffusion_model import TRTDiffusionBackbone # add output directory to tensorrt search path if "tensorrt" in folder_paths.folder_names_and_paths: @@ -28,83 +23,43 @@ {".engine"}, ) -class TQDMProgressMonitor(trt.IProgressMonitor): - def __init__(self): - trt.IProgressMonitor.__init__(self) - self._active_phases = {} - self._step_result = True - self.max_indent = 5 - def phase_start(self, phase_name, parent_phase, num_steps): - leave = False - try: - if parent_phase is not None: - nbIndents = ( - self._active_phases.get(parent_phase, {}).get( - "nbIndents", self.max_indent - ) - + 1 - ) - if nbIndents >= 
self.max_indent: - return - else: - nbIndents = 0 - leave = True - self._active_phases[phase_name] = { - "tq": tqdm( - total=num_steps, desc=phase_name, leave=leave, position=nbIndents - ), - "nbIndents": nbIndents, - "parent_phase": parent_phase, +class TensorRTLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "unet_name": (folder_paths.get_filename_list("tensorrt"),), + "model_type": (list(supported_models.keys()),), } - except KeyboardInterrupt: - # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete. - _step_result = False + } - def phase_finish(self, phase_name): - try: - if phase_name in self._active_phases.keys(): - self._active_phases[phase_name]["tq"].update( - self._active_phases[phase_name]["tq"].total - - self._active_phases[phase_name]["tq"].n - ) + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_unet" + CATEGORY = "TensorRT" - parent_phase = self._active_phases[phase_name].get("parent_phase", None) - while parent_phase is not None: - self._active_phases[parent_phase]["tq"].refresh() - parent_phase = self._active_phases[parent_phase].get( - "parent_phase", None - ) - if ( - self._active_phases[phase_name]["parent_phase"] - in self._active_phases.keys() - ): - self._active_phases[ - self._active_phases[phase_name]["parent_phase"] - ]["tq"].refresh() - del self._active_phases[phase_name] - pass - except KeyboardInterrupt: - _step_result = False + def load_unet(self, unet_name, model_type): + unet_path = folder_paths.get_full_path("tensorrt", unet_name) + model = TRTDiffusionBackbone.load_trt_model(unet_path, model_type) + return ( + comfy.model_patcher.ModelPatcher( + model, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device(), + ), + ) - def step_complete(self, phase_name, step): - try: - if phase_name in self._active_phases.keys(): - self._active_phases[phase_name]["tq"].update( - step - self._active_phases[phase_name]["tq"].n - ) - return self._step_result - except KeyboardInterrupt: - # There is no need to propagate this exception to TensorRT. We can simply cancel the build. 
- return False - class TRT_MODEL_CONVERSION_BASE: def __init__(self): self.output_dir = folder_paths.get_output_directory() self.temp_dir = folder_paths.get_temp_directory() self.timing_cache_path = os.path.normpath( - os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt")) + os.path.join( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt" + ) + ) ) RETURN_TYPES = () @@ -116,24 +71,6 @@ def __init__(self): def INPUT_TYPES(s): raise NotImplementedError - # Sets up the builder to use the timing cache file, and creates it if it does not already exist - def _setup_timing_cache(self, config: trt.IBuilderConfig): - buffer = b"" - if os.path.exists(self.timing_cache_path): - with open(self.timing_cache_path, mode="rb") as timing_cache_file: - buffer = timing_cache_file.read() - print("Read {} bytes from timing cache.".format(len(buffer))) - else: - print("No timing cache found; Initializing a new one.") - timing_cache: trt.ITimingCache = config.create_timing_cache(buffer) - config.set_timing_cache(timing_cache, ignore_mismatch=True) - - # Saves the config's timing cache to file - def _save_timing_cache(self, config: trt.IBuilderConfig): - timing_cache: trt.ITimingCache = config.get_timing_cache() - with open(self.timing_cache_path, "wb") as timing_cache_file: - timing_cache_file.write(memoryview(timing_cache.serialize())) - def _convert( self, model, @@ -152,9 +89,9 @@ def _convert( context_max, num_video_frames, is_static: bool, - output_onnx: bool = False, + output_onnx: Optional[str] = None, ): - if not output_onnx: + if output_onnx is None: output_onnx = os.path.normpath( os.path.join( os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx" @@ -163,55 +100,16 @@ def _convert( os.makedirs(os.path.dirname(output_onnx), exist_ok=True) comfy.model_management.unload_all_models() - comfy.model_management.load_models_gpu([model], force_patch_weights=True, force_full_load=True) - export_onnx(model, output_onnx) - - model_type = ModelType.detect_version(model) - y_dim, extra_input = get_extra_input(model, model_type) - input_names, _, _ = get_io_names(y_dim, extra_input) - inputs_shapes_min = get_shape(model, model_type, batch_size_min, width_min, height_min, context_min, num_video_frames, y_dim, extra_input) - inputs_shapes_opt = get_shape(model, model_type, batch_size_opt, width_opt, height_opt, context_opt, num_video_frames, y_dim, extra_input) - inputs_shapes_max = get_shape(model, model_type, batch_size_max, width_max, height_max, context_max, num_video_frames, y_dim, extra_input) - - comfy.model_management.unload_all_models() - comfy.model_management.soft_empty_cache() - - # TRT conversion starts here - logger = trt.Logger(trt.Logger.INFO) - builder = trt.Builder(logger) - - network = builder.create_network( - 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) - ) - parser = trt.OnnxParser(network, logger) - success = parser.parse_from_file(output_onnx) - for idx in range(parser.num_errors): - print(parser.get_error(idx)) - - if not success: - print("ONNX load ERROR") - return () - - config = builder.create_builder_config() - profile = builder.create_optimization_profile() - self._setup_timing_cache(config) - config.progress_monitor = TQDMProgressMonitor() - - prefix_encode = "" - for k in range(len(input_names)): - min_shape = inputs_shapes_min[k] - opt_shape = inputs_shapes_opt[k] - max_shape = inputs_shapes_max[k] - profile.set_shape(input_names[k], min_shape, opt_shape, max_shape) - - # Encode shapes to 
filename - encode = lambda a: ".".join(map(lambda x: str(x), a)) - prefix_encode += "{}#{}#{}#{};".format( - input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape) + comfy.model_management.load_models_gpu( + [model], force_patch_weights=True, force_full_load=True ) + export_onnx(model, output_onnx) - config.add_optimization_profile(profile) + model_version = detect_version_from_model(model) + model_helper = get_helper_from_model(model) + trt_model = TRTDiffusionBackbone(model_helper) + filename_prefix = f"{filename_prefix}_{model_version}" if is_static: filename_prefix = "{}_${}".format( filename_prefix, @@ -249,10 +147,6 @@ def _convert( ), ) - serialized_engine = builder.build_serialized_network(network, config) - if serialized_engine is None: - raise Exception("Failed to build Engine") - full_output_folder, filename, counter, subfolder, filename_prefix = ( folder_paths.get_save_image_path(filename_prefix, self.output_dir) ) @@ -260,11 +154,36 @@ def _convert( full_output_folder, f"{filename}_{counter:05}_.engine" ) - with open(output_trt_engine, "wb") as f: - f.write(serialized_engine) - - self._save_timing_cache(config) - + batch_multiplier = ( + 2 if model_helper.is_conditional else 1 + ) # TODO lets see if we really want this + if model_version == "SVD_img2vid": + batch_multiplier *= num_video_frames + success = trt_model.build( + output_onnx, + output_trt_engine, + self.timing_cache_path, + opt_config={ + "batch_size": batch_size_opt * batch_multiplier, + "height": height_opt, + "width": width_opt, + "context_len": context_opt * model_helper.context_len, + }, + min_config={ + "batch_size": batch_size_min * batch_multiplier, + "height": height_min, + "width": width_min, + "context_len": context_min * model_helper.context_len, + }, + max_config={ + "batch_size": batch_size_max * batch_multiplier, + "height": height_max, + "width": width_max, + "context_len": context_max * model_helper.context_len, + }, + ) + if not success: + raise RuntimeError("Engine Build Failed") return () @@ -398,7 +317,7 @@ def INPUT_TYPES(s): }, "optional": { "onnx_model_path": ("STRING", {"default": "", "forceInput": True}), - } + }, } def convert( @@ -499,7 +418,7 @@ def INPUT_TYPES(s): }, "optional": { "onnx_model_path": ("STRING", {"default": "", "forceInput": True}), - } + }, } def convert( @@ -535,10 +454,12 @@ def convert( NODE_CLASS_MAPPINGS = { + "TensorRTLoader": TensorRTLoader, "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION, "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION, } NODE_DISPLAY_NAME_MAPPINGS = { + "TensorRTLoader": "TensorRT Loader", "DYNAMIC_TRT_MODEL_CONVERSION": "DYNAMIC TRT_MODEL CONVERSION", "STATIC TRT_MODEL CONVERSION": "STATIC_TRT_MODEL_CONVERSION", } From 1991bafb333f8dcec9dc8d15d67f146a74da2200 Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 13 Nov 2024 01:47:28 -0800 Subject: [PATCH 5/6] Use zero-copy approach for storing TRT engine as torch sd --- models/flux.py | 10 ++++------ models/sd3.py | 6 +++--- models/supported_models.py | 4 +--- onnx_utils/export.py | 3 +-- tensorrt_diffusion_model.py | 5 ++--- 5 files changed, 11 insertions(+), 17 deletions(-) diff --git a/models/flux.py b/models/flux.py index d161c31..c14a26f 100644 --- a/models/flux.py +++ b/models/flux.py @@ -29,12 +29,11 @@ def __init__( self.input_config.update(self.extra_input) if self.use_control: - self.control = self.get_control(double_blocks, single_blocks) + self.control = self.get_control() self.input_config.update(self.control) def to_dict(self): return { - 
self.__name__: { "context_dim": self.context_dim, "input_channels": self.input_channels, "y_dim": self.y_dim, @@ -42,18 +41,17 @@ def to_dict(self): "double_blocks": self.double_blocks, "single_blocks": self.single_blocks, "use_control": self.use_control, - } } - def get_control(self, double_blocks: int, single_blocks: int): + def get_control(self): control_input = {} - for i in range(double_blocks): + for i in range(self.double_blocks): control_input[f"input_control_{i}"] = { "batch": "{batch_size}", "ids": "({height}*{width}//(8*2)**2)", "hidden_size": self.hidden_size, } - for i in range(single_blocks): + for i in range(self.single_blocks): control_input[f"output_control_{i}"] = { "batch": "{batch_size}", "ids": "({height}*{width}//(8*2)**2)", diff --git a/models/sd3.py b/models/sd3.py index 11c3111..a961a39 100644 --- a/models/sd3.py +++ b/models/sd3.py @@ -27,7 +27,7 @@ def __init__( self.input_config.update(self.extra_input) if self.use_control: - self.control = self.get_control(output_blocks) + self.control = self.get_control() self.input_config.update(self.control) def to_dict(self): @@ -40,9 +40,9 @@ def to_dict(self): "use_control": self.use_control, } - def get_control(self, output_blocks: int): + def get_control(self): control_input = {} - for i in range(output_blocks): + for i in range(self.output_blocks): control_input[f"output_control_{i}"] = { "batch": "{batch_size}", "ids": "({height}*{width}//(8*2)**2)", diff --git a/models/supported_models.py b/models/supported_models.py index 1b09fa1..607c8fa 100644 --- a/models/supported_models.py +++ b/models/supported_models.py @@ -63,9 +63,7 @@ def get_helper_from_version(model_version: str, config: dict = {}) -> TRTModelUt def get_helper_from_model(model) -> TRTModelUtil: model_version = detect_version_from_model(model) - helper_cls = supported_models.get(model_version, None) - if helper_cls is None: - raise NotImplementedError("{} is not supported.".format(model_version)) + helper_cls = supported_models.get(model_version, TRTModelUtil) return helper_cls.from_model(model) diff --git a/onnx_utils/export.py b/onnx_utils/export.py index 0e4554a..1532d98 100644 --- a/onnx_utils/export.py +++ b/onnx_utils/export.py @@ -169,7 +169,6 @@ def export_onnx( width: int = 512, num_video_frames: int = 14, context_multiplier: int = 1, - use_control: bool = True, ): model_version = detect_version_from_model(model) model_helper = get_helper_from_model(model) @@ -192,7 +191,7 @@ def export_onnx( ) inputs = get_sample_input(input_shapes, dtype, device) backbone = get_backbone( - model, model_version, input_names, num_video_frames, use_control + model, model_version, input_names, num_video_frames, model_helper.use_control ) dir, name = os.path.split(path) diff --git a/tensorrt_diffusion_model.py b/tensorrt_diffusion_model.py index 684fb8c..a0846f3 100644 --- a/tensorrt_diffusion_model.py +++ b/tensorrt_diffusion_model.py @@ -120,10 +120,9 @@ def build( if self.engine is None: raise Exception("Failed to build Engine") + engine_view = memoryview(self.engine) model = { - "engine": torch.ByteTensor( - bytearray(self.engine) - ), # TODO this isn't very efficient + "engine": torch.frombuffer(engine_view, dtype=torch.uint8), "config": self.model.to_dict(), } torch.save(model, engine_path) From d2fcef4d9063b0c88627e1815ef021fab391515b Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 13 Nov 2024 03:15:44 -0800 Subject: [PATCH 6/6] clean-up: reformat code, fix typos, optimize imports --- .github/workflows/publish.yml | 2 +- README.md | 72 +++++----- 
__init__.py | 5 +- models/__init__.py | 2 +- models/auraflow.py | 2 +- models/baseline.py | 17 ++- models/flux.py | 76 ++++++----- models/sd3.py | 19 +-- models/sd_unet.py | 243 +++++++++++++++++----------------- models/supported_models.py | 5 +- onnx_nodes.py | 8 +- onnx_utils/export.py | 38 +++--- pyproject.toml | 4 +- tensorrt_diffusion_model.py | 84 ++++++------ tensorrt_nodes.py | 105 +++++++-------- 15 files changed, 351 insertions(+), 331 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f10f0c4..e669db7 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -17,5 +17,5 @@ jobs: - name: Publish Custom Node uses: Comfy-Org/publish-node-action@main with: - ## Add your own personal access token to your Github Repository secrets and reference it here. + ## Add your own personal access token to your GitHub Repository secrets and reference it here. personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} \ No newline at end of file diff --git a/README.md b/README.md index 63df71a..69af30e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # TensorRT Node for ComfyUI -This node enables the best performance on NVIDIA RTX™ Graphics Cards - (GPUs) for Stable Diffusion by leveraging NVIDIA TensorRT. +This node enables the best performance on NVIDIA RTX™ Graphics Cards(GPUs) for Stable Diffusion by leveraging NVIDIA +TensorRT. Supports: @@ -11,8 +11,7 @@ Supports: - SDXL - SDXL Turbo - Stable Video Diffusion -- Stable Video Diffusion-XT  -- AuraFlow +- Stable Video Diffusion-XT- AuraFlow - Flux Requirements: @@ -31,7 +30,8 @@ Requirements: The recommended way to install these nodes is to use the [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) to easily install them to your ComfyUI instance. -You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the requirements like: +You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the +requirements like: ``` cd custom_nodes @@ -62,7 +62,7 @@ You have the option to build either dynamic or static TensorRT engines: Note: Most users will prefer dynamic engines, but static engines can be useful if you use a specific resolution + batch size combination most of the time. Static engines also require less VRAM; the wider the dynamic -range, the more VRAM will be consumed. +range, the more VRAM will be consumed. ## Instructions @@ -71,19 +71,19 @@ These .json files can be loaded in ComfyUI. ### Building A TensorRT Engine From a Checkpoint -1. Add a Load Checkpoint Node -2. Add either a Static Model TensorRT Conversion node or a Dynamic - Model TensorRT Conversion node to ComfyUI -3. ![](readme_images/image3.png) -4. Connect the Load Checkpoint Model output to the TensorRT Conversion - Node Model input. -5. ![](readme_images/image5.png) -6. ![](readme_images/image2.png) -7. To help identify the converted TensorRT model, provide a meaningful - filename prefix, add this filename after “tensorrt/” -8. ![](readme_images/image9.png) - -9. Click on Queue Prompt to start building the TensorRT Engines +1. Add a Load Checkpoint Node +2. Add either a Static Model TensorRT Conversion node or a Dynamic + Model TensorRT Conversion node to ComfyUI +3. ![](readme_images/image3.png) +4. Connect the Load Checkpoint Model output to the TensorRT Conversion + Node Model input. +5. ![](readme_images/image5.png) +6. ![](readme_images/image2.png) +7. 
To help identify the converted TensorRT model, provide a meaningful + filename prefix, add this filename after “tensorrt/” +8. ![](readme_images/image9.png) + +9. Click on Queue Prompt to start building the TensorRT Engines 10. ![](readme_images/image7.png) ![](readme_images/image11.png) @@ -96,9 +96,9 @@ the console. ![](readme_images/image4.png) -The first time generating an engine for a checkpoint will take awhile. +The first time generating an engine for a checkpoint will take a while. Additional engines generated thereafter for the same checkpoint will be -much faster. Generating engines can take anywhere from 3-10 minutes for +much faster. Generating engines can take anywhere from 3-10 minutes for the image generation models and 10-25 minutes for SVD. SVD-XT is an extremely extensive model - engine build times may take up to an hour. @@ -115,33 +115,33 @@ TensorRT Engines are loaded using the TensorRT Loader node. ComfyUI TensorRT engines are not yet compatible with ControlNets or LoRAs. Compatibility will be enabled in a future update. -1. Add a TensorRT Loader node -2. Note, if a TensorRT Engine has been created during a ComfyUI - session, it will not show up in the TensorRT Loader until the - ComfyUI interface has been refreshed (F5 to refresh browser). -3. ![](readme_images/image6.png) -4. Select a TensorRT Engine from the unet_name dropdown -5. Dynamic Engines will use a filename format of: +1. Add a TensorRT Loader node +2. Note, if a TensorRT Engine has been created during a ComfyUI + session, it will not show up in the TensorRT Loader until the + ComfyUI interface has been refreshed (F5 to refresh browser). +3. ![](readme_images/image6.png) +4. Select a TensorRT Engine from the unet_name dropdown +5. Dynamic Engines will use a filename format of:   -1. dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt -2. dyn=dynamic, b=batch size, h=height, w=width +1. dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt +2. dyn=dynamic, b=batch size, h=height, w=width   -6. Static Engine will use a filename format of: +6. Static Engine will use a filename format of:   -1. stat-b-opt-h-opt-w-opt -2. stat=static, b=batch size, h=height, w=width +1. stat-b-opt-h-opt-w-opt +2. stat=static, b=batch size, h=height, w=width   -7. ![](readme_images/image8.png) -8. The model_type must match the model type of the TensorRT engine. -9. ![](readme_images/image10.png) +7. ![](readme_images/image8.png) +8. The model_type must match the model type of the TensorRT engine. +9. ![](readme_images/image10.png) 10. The CLIP and VAE for the workflow will need to be utilized from the original model checkpoint, the MODEL output from the TensorRT Loader will be connected to the Sampler. 
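
For readers who would rather drive the new conversion path from a script than from the ComfyUI graph, the pieces introduced in this series compose roughly as in the sketch below. It is illustrative only and not part of the patch: it assumes the repository is importable as `ComfyUI_TensorRT` from ComfyUI's `custom_nodes` directory (adjust the import root to your install) and that `model` is a ComfyUI `ModelPatcher` that has already been loaded by a checkpoint loader; the shape dictionaries mirror what the dynamic conversion node passes to `TRTDiffusionBackbone.build`.

```
import os

import comfy.model_management

# Assumed import root; adjust to wherever this repo lives under custom_nodes.
from ComfyUI_TensorRT.models import get_helper_from_model
from ComfyUI_TensorRT.onnx_utils.export import export_onnx
from ComfyUI_TensorRT.tensorrt_diffusion_model import TRTDiffusionBackbone


def build_dynamic_engine(model, onnx_path="output/model.onnx",
                         engine_path="output/model_dyn.engine",
                         timing_cache="timing_cache.trt"):
    """Rough sketch of the dynamic conversion flow: trace to ONNX, then build TRT."""
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

    # Fully load the patched model before tracing, as the conversion node does.
    comfy.model_management.unload_all_models()
    comfy.model_management.load_models_gpu(
        [model], force_patch_weights=True, force_full_load=True
    )
    export_onnx(model, onnx_path)

    helper = get_helper_from_model(model)   # per-architecture shape/IO helper
    trt_model = TRTDiffusionBackbone(helper)

    # Batch size is doubled for conditional models and context_len scales with
    # the helper's base token length, matching the conversion node's configs.
    mult = 2 if helper.is_conditional else 1
    opt_cfg = {"batch_size": 1 * mult, "height": 1024, "width": 1024,
               "context_len": 1 * helper.context_len}
    min_cfg = {"batch_size": 1 * mult, "height": 768, "width": 768,
               "context_len": 1 * helper.context_len}
    max_cfg = {"batch_size": 4 * mult, "height": 1280, "width": 1280,
               "context_len": 3 * helper.context_len}

    if not trt_model.build(onnx_path, engine_path, timing_cache,
                           opt_config=opt_cfg, min_config=min_cfg,
                           max_config=max_cfg):
        raise RuntimeError("Engine build failed")
    return engine_path
```

The resulting `.engine` file can then be selected from the TensorRT Loader node exactly as described in the README instructions above.
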
diff --git a/__init__.py b/__init__.py index e4a40be..cd30b14 100644 --- a/__init__.py +++ b/__init__.py @@ -1,8 +1,7 @@ -from .tensorrt_nodes import NODE_CLASS_MAPPINGS as TRT_CLASS_MAP -from .tensorrt_nodes import NODE_DISPLAY_NAME_MAPPINGS as TRT_NAME_MAP - from .onnx_nodes import NODE_CLASS_MAPPING as ONNX_CLASS_MAP from .onnx_nodes import NODE_DISPLAY_NAME_MAPPINGS as ONNX_NAME_MAP +from .tensorrt_nodes import NODE_CLASS_MAPPINGS as TRT_CLASS_MAP +from .tensorrt_nodes import NODE_DISPLAY_NAME_MAPPINGS as TRT_NAME_MAP NODE_CLASS_MAPPINGS = TRT_CLASS_MAP | ONNX_CLASS_MAP NODE_DISPLAY_NAME_MAPPINGS = TRT_NAME_MAP | ONNX_NAME_MAP diff --git a/models/__init__.py b/models/__init__.py index 1e1e84e..52d25c3 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,3 +1,4 @@ +from .baseline import TRTModelUtil from .supported_models import ( supported_models, unsupported_models, @@ -6,4 +7,3 @@ get_helper_from_model, get_model_from_version, ) -from .baseline import TRTModelUtil diff --git a/models/auraflow.py b/models/auraflow.py index bb991e1..9becebc 100644 --- a/models/auraflow.py +++ b/models/auraflow.py @@ -3,7 +3,7 @@ class AuraFlow_TRT(TRTModelUtil): def __init__( - self, context_dim=2048, input_channels=4, context_len=256, **kwargs + self, context_dim=2048, input_channels=4, context_len=256, **kwargs ) -> None: super().__init__( context_dim=context_dim, diff --git a/models/baseline.py b/models/baseline.py index f72f30e..dfebe59 100644 --- a/models/baseline.py +++ b/models/baseline.py @@ -3,13 +3,13 @@ class TRTModelUtil: def __init__( - self, - context_dim: int, - input_channels: int, - context_len: int, - use_control: bool = False, - *args, - **kwargs, + self, + context_dim: int, + input_channels: int, + context_len: int, + use_control: bool = False, + *args, + **kwargs, ) -> None: super().__init__(*args, **kwargs) self.context_dim = context_dim @@ -100,3 +100,6 @@ def get_control(self, *args, **kwargs) -> dict: @classmethod def from_model(cls, model, **kwargs): raise NotImplementedError + + def model_attributes(self, **kwargs) -> dict: + return {} diff --git a/models/flux.py b/models/flux.py index c14a26f..31d273b 100644 --- a/models/flux.py +++ b/models/flux.py @@ -1,18 +1,19 @@ -from .baseline import TRTModelUtil import torch +from .baseline import TRTModelUtil + class FLuxBase(TRTModelUtil): def __init__( - self, - context_dim: int, - input_channels: int, - y_dim: int, - hidden_size: int, - double_blocks: int, - single_blocks: int, - *args, - **kwargs, + self, + context_dim: int, + input_channels: int, + y_dim: int, + hidden_size: int, + double_blocks: int, + single_blocks: int, + *args, + **kwargs, ) -> None: super().__init__(context_dim, input_channels, 256, *args, **kwargs) @@ -34,13 +35,13 @@ def __init__( def to_dict(self): return { - "context_dim": self.context_dim, - "input_channels": self.input_channels, - "y_dim": self.y_dim, - "hidden_size": self.hidden_size, - "double_blocks": self.double_blocks, - "single_blocks": self.single_blocks, - "use_control": self.use_control, + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "y_dim": self.y_dim, + "hidden_size": self.hidden_size, + "double_blocks": self.double_blocks, + "single_blocks": self.single_blocks, + "use_control": self.use_control, } def get_control(self): @@ -74,17 +75,22 @@ def from_model(cls, model, **kwargs): **kwargs, ) + def model_attributes(self, **kwargs) -> dict: + return { + "double_blocks": [None, ] * self.double_blocks, + } + class Flux_TRT(FLuxBase): def __init__( - self, 
- context_dim=4096, - input_channels=16, - y_dim=768, - hidden_size=3072, - double_blocks=19, - single_blocks=28, - **kwargs, + self, + context_dim=4096, + input_channels=16, + y_dim=768, + hidden_size=3072, + double_blocks=19, + single_blocks=28, + **kwargs, ): super().__init__( context_dim=context_dim, @@ -97,20 +103,20 @@ def __init__( ) @classmethod - def from_model(cls, model): + def from_model(cls, model, **kwargs): return super(Flux_TRT, cls).from_model(model, use_control=True) class FluxSchnell_TRT(FLuxBase): def __init__( - self, - context_dim=4096, - input_channels=16, - y_dim=768, - hidden_size=3072, - double_blocks=19, - single_blocks=28, - **kwargs, + self, + context_dim=4096, + input_channels=16, + y_dim=768, + hidden_size=3072, + double_blocks=19, + single_blocks=28, + **kwargs, ): super().__init__( context_dim=context_dim, @@ -123,5 +129,5 @@ def __init__( ) @classmethod - def from_model(cls, model): + def from_model(cls, model, **kwargs): return super(FluxSchnell_TRT, cls).from_model(model, use_control=True) diff --git a/models/sd3.py b/models/sd3.py index a961a39..fa2521b 100644 --- a/models/sd3.py +++ b/models/sd3.py @@ -1,17 +1,18 @@ -from .baseline import TRTModelUtil import torch +from .baseline import TRTModelUtil + class SD3_TRT(TRTModelUtil): def __init__( - self, - context_dim: int = 4096, - input_channels: int = 16, - y_dim: int = 2048, - hidden_size: int = 1536, - output_blocks: int = 24, - *args, - **kwargs, + self, + context_dim: int = 4096, + input_channels: int = 16, + y_dim: int = 2048, + hidden_size: int = 1536, + output_blocks: int = 24, + *args, + **kwargs, ) -> None: super().__init__(context_dim, input_channels, 77, *args, **kwargs) diff --git a/models/sd_unet.py b/models/sd_unet.py index 742158b..d7d42a2 100644 --- a/models/sd_unet.py +++ b/models/sd_unet.py @@ -1,6 +1,4 @@ -from .baseline import TRTModelUtil import torch - from comfy.supported_models import ( SD15, SD20, @@ -17,19 +15,21 @@ SDXL_instructpix2pix, ) +from .baseline import TRTModelUtil + class UNetTRT(TRTModelUtil): def __init__( - self, - context_dim: int, - input_channels: int, - y_dim: int, - hidden_size: int, - channel_mult: tuple[int], - num_res_blocks: tuple[int], - context_len: int = 77, - *args, - **kwargs, + self, + context_dim: int, + input_channels: int, + y_dim: int, + hidden_size: int, + channel_mult: tuple[int], + num_res_blocks: tuple[int], + context_len: int = 77, + *args, + **kwargs, ) -> None: super().__init__(context_dim, input_channels, context_len, *args, **kwargs) @@ -39,7 +39,7 @@ def __init__( self.channel_mult = channel_mult self.num_res_blocks = num_res_blocks - self.set_block_chans() + self.input_block_chans = self.set_block_chans() if self.y_dim: self.input_config.update({"y": {"batch": "{batch_size}", "y_dim": y_dim}}) @@ -62,7 +62,7 @@ def set_block_chans(self): ch = out_ch ds *= 2 input_block_chans.append((ch, ds)) - self.input_block_chans = input_block_chans + return input_block_chans @classmethod def from_model(cls, model, **kwargs): @@ -114,7 +114,6 @@ def get_control(self): f"width{d}": "{width}//(8*" + str(d) + ")", } return control_input - # return {} def get_dtype(self): return torch.float16 @@ -122,14 +121,14 @@ def get_dtype(self): class SD15_TRT(UNetTRT): def __init__( - self, - context_dim=SD15.unet_config["context_dim"], - input_channels=4, - y_dim=0, - hidden_size=SD15.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - **kwargs, + self, + context_dim=SD15.unet_config["context_dim"], + input_channels=4, 
+ y_dim=0, + hidden_size=SD15.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -142,20 +141,20 @@ def __init__( ) @classmethod - def from_model(cls, model): + def from_model(cls, model, **kwargs): return super(SD15_TRT, cls).from_model(model, use_control=True) class SD20_TRT(UNetTRT): def __init__( - self, - context_dim=SD20.unet_config["context_dim"], - input_channels=4, - y_dim=0, - hidden_size=SD20.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - **kwargs, + self, + context_dim=SD20.unet_config["context_dim"], + input_channels=4, + y_dim=0, + hidden_size=SD20.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -168,20 +167,20 @@ def __init__( ) @classmethod - def from_model(cls, model): + def from_model(cls, model, **kwargs): return super(SD20_TRT, cls).from_model(model, use_control=True) class SD21UnclipL_TRT(UNetTRT): def __init__( - self, - context_dim=SD21UnclipL.unet_config["context_dim"], - input_channels=4, - y_dim=SD21UnclipL.unet_config["adm_in_channels"], - hidden_size=SD21UnclipL.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - **kwargs, + self, + context_dim=SD21UnclipL.unet_config["context_dim"], + input_channels=4, + y_dim=SD21UnclipL.unet_config["adm_in_channels"], + hidden_size=SD21UnclipL.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -196,14 +195,14 @@ def __init__( class SD21UnclipH_TRT(UNetTRT): def __init__( - self, - context_dim=SD21UnclipH.unet_config["context_dim"], - input_channels=4, - y_dim=SD21UnclipH.unet_config["adm_in_channels"], - hidden_size=SD21UnclipH.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - **kwargs, + self, + context_dim=SD21UnclipH.unet_config["context_dim"], + input_channels=4, + y_dim=SD21UnclipH.unet_config["adm_in_channels"], + hidden_size=SD21UnclipH.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -218,14 +217,14 @@ def __init__( class SDXLRefiner_TRT(UNetTRT): def __init__( - self, - context_dim=SDXLRefiner.unet_config["context_dim"], - input_channels=4, - y_dim=SDXLRefiner.unet_config["adm_in_channels"], - hidden_size=SDXLRefiner.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - **kwargs, + self, + context_dim=SDXLRefiner.unet_config["context_dim"], + input_channels=4, + y_dim=SDXLRefiner.unet_config["adm_in_channels"], + hidden_size=SDXLRefiner.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -240,14 +239,14 @@ def __init__( class SDXL_TRT(UNetTRT): def __init__( - self, - context_dim=SDXL.unet_config["context_dim"], - input_channels=4, - y_dim=SDXL.unet_config["adm_in_channels"], - hidden_size=SDXL.unet_config["model_channels"], - channel_mult=(1, 2, 4), - num_res_blocks=(2, 2, 2), - **kwargs, + self, + context_dim=SDXL.unet_config["context_dim"], + input_channels=4, + y_dim=SDXL.unet_config["adm_in_channels"], + hidden_size=SDXL.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -260,20 +259,20 @@ def __init__( ) @classmethod - def 
from_model(cls, model): + def from_model(cls, model, **kwargs): return super(SDXL_TRT, cls).from_model(model, use_control=True) class SSD1B_TRT(UNetTRT): def __init__( - self, - context_dim=SSD1B.unet_config["context_dim"], - input_channels=4, - y_dim=SSD1B.unet_config["adm_in_channels"], - hidden_size=SSD1B.unet_config["model_channels"], - channel_mult=(1, 2, 4), - num_res_blocks=(2, 2, 2), - **kwargs, + self, + context_dim=SSD1B.unet_config["context_dim"], + input_channels=4, + y_dim=SSD1B.unet_config["adm_in_channels"], + hidden_size=SSD1B.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -288,14 +287,14 @@ def __init__( class Segmind_Vega_TRT(UNetTRT): def __init__( - self, - context_dim=Segmind_Vega.unet_config["context_dim"], - input_channels=4, - y_dim=Segmind_Vega.unet_config["adm_in_channels"], - hidden_size=Segmind_Vega.unet_config["model_channels"], - channel_mult=(1, 2, 4), - num_res_blocks=(2, 2, 2), # TODO - **kwargs, + self, + context_dim=Segmind_Vega.unet_config["context_dim"], + input_channels=4, + y_dim=Segmind_Vega.unet_config["adm_in_channels"], + hidden_size=Segmind_Vega.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, ): super().__init__( context_dim, @@ -310,14 +309,14 @@ def __init__( class KOALA_700M_TRT(UNetTRT): def __init__( - self, - context_dim=KOALA_700M.unet_config["context_dim"], - input_channels=4, - y_dim=KOALA_700M.unet_config["adm_in_channels"], - hidden_size=KOALA_700M.unet_config["model_channels"], - channel_mult=(1, 2, 4), - num_res_blocks=(2, 2, 2), # TODO - **kwargs, + self, + context_dim=KOALA_700M.unet_config["context_dim"], + input_channels=4, + y_dim=KOALA_700M.unet_config["adm_in_channels"], + hidden_size=KOALA_700M.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, ): super().__init__( context_dim, @@ -332,14 +331,14 @@ def __init__( class KOALA_1B_TRT(UNetTRT): def __init__( - self, - context_dim=KOALA_1B.unet_config["context_dim"], - input_channels=4, - y_dim=KOALA_1B.unet_config["adm_in_channels"], - hidden_size=KOALA_1B.unet_config["model_channels"], - channel_mult=(1, 2, 4), - num_res_blocks=(2, 2, 2), # TODO - **kwargs, + self, + context_dim=KOALA_1B.unet_config["context_dim"], + input_channels=4, + y_dim=KOALA_1B.unet_config["adm_in_channels"], + hidden_size=KOALA_1B.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, ): super().__init__( context_dim, @@ -354,14 +353,14 @@ def __init__( class SVD_img2vid_TRT(UNetTRT): def __init__( - self, - context_dim=SVD_img2vid.unet_config["context_dim"], - input_channels=8, - y_dim=SVD_img2vid.unet_config["adm_in_channels"], - hidden_size=SVD_img2vid.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - **kwargs, + self, + context_dim=SVD_img2vid.unet_config["context_dim"], + input_channels=8, + y_dim=SVD_img2vid.unet_config["adm_in_channels"], + hidden_size=SVD_img2vid.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, ): super().__init__( context_dim, @@ -378,14 +377,14 @@ def __init__( class SD15_instructpix2pix_TRT(UNetTRT): def __init__( - self, - context_dim=SD15_instructpix2pix.unet_config["context_dim"], - input_channels=8, - y_dim=0, - hidden_size=SD15_instructpix2pix.unet_config["model_channels"], - channel_mult=(1, 2, 4, 4), - num_res_blocks=(2, 2, 2, 2), - 
**kwargs,
+            self,
+            context_dim=SD15_instructpix2pix.unet_config["context_dim"],
+            input_channels=8,
+            y_dim=0,
+            hidden_size=SD15_instructpix2pix.unet_config["model_channels"],
+            channel_mult=(1, 2, 4, 4),
+            num_res_blocks=(2, 2, 2, 2),
+            **kwargs,
     ):
         super().__init__(
             context_dim,
@@ -400,14 +399,14 @@ def __init__(

 class SDXL_instructpix2pix_TRT(UNetTRT):
     def __init__(
-        self,
-        context_dim=SDXL_instructpix2pix.unet_config["context_dim"],
-        input_channels=8,
-        y_dim=SDXL_instructpix2pix.unet_config["adm_in_channels"],
-        hidden_size=SDXL_instructpix2pix.unet_config["model_channels"],
-        channel_mult=(1, 2, 4),
-        num_res_blocks=(2, 2, 2),
-        **kwargs,
+            self,
+            context_dim=SDXL_instructpix2pix.unet_config["context_dim"],
+            input_channels=8,
+            y_dim=SDXL_instructpix2pix.unet_config["adm_in_channels"],
+            hidden_size=SDXL_instructpix2pix.unet_config["model_channels"],
+            channel_mult=(1, 2, 4),
+            num_res_blocks=(2, 2, 2),
+            **kwargs,
     ):
         super().__init__(
             context_dim,
diff --git a/models/supported_models.py b/models/supported_models.py
index 607c8fa..bfd6c19 100644
--- a/models/supported_models.py
+++ b/models/supported_models.py
@@ -1,6 +1,8 @@
+import comfy.supported_models
+
+from .auraflow import AuraFlow_TRT
 from .baseline import TRTModelUtil
 from .flux import Flux_TRT, FluxSchnell_TRT
-from .auraflow import AuraFlow_TRT
 from .sd3 import SD3_TRT
 from .sd_unet import (
     SD15_TRT,
@@ -16,7 +18,6 @@
     Segmind_Vega_TRT,
     SVD_img2vid_TRT,
 )
-import comfy.supported_models

 supported_models = {
     "SD15": SD15_TRT,
diff --git a/onnx_nodes.py b/onnx_nodes.py
index 1ae9969..0414f1e 100644
--- a/onnx_nodes.py
+++ b/onnx_nodes.py
@@ -1,10 +1,12 @@
 import os
-from .onnx_utils.export import export_onnx
+
 import comfy
 import folder_paths

+from .onnx_utils.export import export_onnx
+

-class ONNX_EXPORT:
+class ONNXExport:
     def __init__(self) -> None:
         pass

@@ -65,7 +67,7 @@ def select_onnx_model(self, model_name):


 NODE_CLASS_MAPPING = {
-    "ONNX_EXPORT": ONNX_EXPORT,
+    "ONNX_EXPORT": ONNXExport,
     "ONNXModelSelector": ONNXModelSelector,
 }

diff --git a/onnx_utils/export.py b/onnx_utils/export.py
index 1532d98..74882f8 100644
--- a/onnx_utils/export.py
+++ b/onnx_utils/export.py
@@ -1,14 +1,16 @@
-import onnx
-import torch
-import numpy as np
-from onnx import numpy_helper
 import json
-import comfy
 import os
+import time
 from typing import List
-from onnx.external_data_helper import _get_all_tensors, ExternalDataInfo
+
+import comfy
 import folder_paths
-import time
+import numpy as np
+import onnx
+import torch
+from onnx import numpy_helper
+from onnx.external_data_helper import _get_all_tensors, ExternalDataInfo
+

 from ..models import detect_version_from_model, get_helper_from_model

@@ -22,7 +24,7 @@ def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]:
         ExternalDataInfo(tensor).location
         for tensor in model_tensors
         if tensor.HasField("data_location")
-        and tensor.data_location == onnx.TensorProto.EXTERNAL
+           and tensor.data_location == onnx.TensorProto.EXTERNAL
     ]
     return model_tensors_ext

@@ -130,8 +132,8 @@ def init_hash_map(onnx_opt_model):
         wt_t_hash = hash(np.transpose(wt).data.tobytes())

         for initializer_name, (
-            initializer_hash,
-            initializer_shape,
+                initializer_hash,
+                initializer_shape,
         ) in initializer_hash_mapping.items():
             # Due to constant folding, some weights are transposed during export
             # To account for the transpose op, we compare the initializer hash to the
@@ -162,13 +164,13 @@ def init_hash_map(onnx_opt_model):


 def export_onnx(
-    model,
-    path,
-    batch_size: int = 1,
-    height: int = 512,
-    width: int = 512,
-    num_video_frames: int = 14,
-    context_multiplier: int = 1,
+        model,
+        path,
+        batch_size: int = 1,
+        height: int = 512,
+        width: int = 512,
+        num_video_frames: int = 14,
+        context_multiplier: int = 1,
 ):
     model_version = detect_version_from_model(model)
     model_helper = get_helper_from_model(model)
@@ -194,7 +196,7 @@ def export_onnx(
         model, model_version, input_names, num_video_frames, model_helper.use_control
     )

-    dir, name = os.path.split(path)
+    _, name = os.path.split(path)
     temp_path = os.path.join(
         folder_paths.get_temp_directory(), "{}".format(time.time())
     )
diff --git a/pyproject.toml b/pyproject.toml
index e7adb94..f2ff4f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,8 +4,8 @@ description = "TensorRT Node for ComfyUI\nThis node enables the best performance
 version = "0.1.8"
 license = "LICENSE"
 dependencies = [
-    "tensorrt>=10.0.1",
-    "onnx"
+    "tensorrt>=10.6",
+    "onnx"
 ]

 [project.urls]
diff --git a/tensorrt_diffusion_model.py b/tensorrt_diffusion_model.py
index a0846f3..b3c4d4c 100644
--- a/tensorrt_diffusion_model.py
+++ b/tensorrt_diffusion_model.py
@@ -1,10 +1,12 @@
-import torch
-import tensorrt as trt
 import os
-from typing import Optional
-from tqdm import tqdm
 from math import prod
+from typing import Optional
+
 import comfy.model_management
+import tensorrt as trt
+import torch
+from tqdm import tqdm
+
 from .models import get_model_from_version, TRTModelUtil

 trt.init_libnvinfer_plugins(None, "")
@@ -31,6 +33,9 @@ def __init__(self, model_helper: TRTModelUtil, *args, **kwargs) -> None:
         self.engine = None
         self.model = model_helper

+        for k, v in self.model.model_attributes().items():
+            setattr(self, k, v)
+
         self.dtype = self.model.get_dtype()
         self.device = comfy.model_management.get_torch_device()

@@ -45,7 +50,8 @@ def __init__(self, model_helper: TRTModelUtil, *args, **kwargs) -> None:
         self.extra_inputs: dict[str, torch.Tensor] = {}

     # Sets up the builder to use the timing cache file, and creates it if it does not already exist
-    def _setup_timing_cache(self, config: trt.IBuilderConfig, timing_cache_path: str):
+    @staticmethod
+    def _setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: str):
         buffer = b""
         if os.path.exists(timing_cache_path):
             with open(timing_cache_path, mode="rb") as timing_cache_file:
@@ -57,7 +63,8 @@ def _setup_timing_cache(self, config: trt.IBuilderConfig, timing_cache_path: str
         config.set_timing_cache(timing_cache, ignore_mismatch=True)

     # Saves the config's timing cache to file
-    def _save_timing_cache(self, config: trt.IBuilderConfig, timing_cache_path: str):
+    @staticmethod
+    def _save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: str):
         timing_cache: trt.ITimingCache = config.get_timing_cache()
         with open(timing_cache_path, "wb") as timing_cache_file:
             timing_cache_file.write(memoryview(timing_cache.serialize()))
@@ -82,19 +89,17 @@ def _create_profile(self, builder, min_config, opt_config, max_config):
         return profile

     def build(
-        self,
-        onnx_path: str,
-        engine_path: str,
-        timing_cache_path: str,
-        opt_config: dict,
-        min_config: Optional[dict] = None,
-        max_config: Optional[dict] = None,
+            self,
+            onnx_path: str,
+            engine_path: str,
+            timing_cache_path: str,
+            opt_config: dict,
+            min_config: Optional[dict] = None,
+            max_config: Optional[dict] = None,
     ) -> bool:
         comfy.model_management.unload_all_models()
         comfy.model_management.soft_empty_cache()

-        # TRT conversion starts here
-
         logger = trt.Logger(trt.Logger.INFO)
         builder = trt.Builder(logger)
         network = builder.create_network(
@@ -110,10 +115,12 @@ def build(
             return False

         config = builder.create_builder_config()
+        # config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
         self._setup_timing_cache(config, timing_cache_path)
         config.progress_monitor = TQDMProgressMonitor()
         profile = self._create_profile(builder, min_config, opt_config, max_config)
         config.add_optimization_profile(profile)
+        config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)  # STRIP_PLAN

         self.engine = builder.build_serialized_network(network, config)

@@ -184,7 +191,6 @@ def setup_tensors(self, model_inputs):
             width *= 8
         _, context_len, _ = model_inputs["context"].shape
         min_batch = dims[0][0]
-        opt_batch = dims[1][0]
         max_batch = dims[2][0]
         # Split batch if our batch is bigger than the max batch size the trt engine supports
         for i in range(max_batch, min_batch - 1, -1):
@@ -226,14 +232,14 @@ def setup_tensors(self, model_inputs):

     @torch.cuda.nvtx.range("__call__")
     def __call__(
-        self,
-        x,
-        timesteps,
-        context,
-        y=None,
-        control=None,
-        transformer_options=None,
-        **kwargs,
+            self,
+            x,
+            timesteps,
+            context,
+            y=None,
+            control=None,
+            transformer_options=None,
+            **kwargs,
     ):
         model_inputs = {"x": x, "timesteps": timesteps, "context": context}

@@ -245,9 +251,9 @@ def __call__(
             for i, tensor in enumerate(control_tensors):
                 model_inputs[f"{control_layer}_control_{i}"] = tensor

-        for k, v in kwargs.items():
-            # TODO actually needed? model_inputs[k] = v
-            pass
+        # TODO actually needed?
+        # for k, v in kwargs.items():
+        #     model_inputs[k] = v

         if self.current_shape != x.shape:
             self.setup_tensors(model_inputs)
@@ -268,7 +274,7 @@ def __call__(
         for i in range(self.curr_split_batch):
             for k, v in model_inputs.items():
                 self.context.set_tensor_address(
-                    k, v[(v.shape[0] // self.curr_split_batch) * i :].data_ptr()
+                    k, v[(v.shape[0] // self.curr_split_batch) * i:].data_ptr()
                 )
             for k in self.extra_inputs.keys():
                 self.context.set_tensor_address(k, self.zero_pool.data_ptr())
@@ -289,22 +295,22 @@ def phase_start(self, phase_name, parent_phase, num_steps):
         leave = False
         try:
             if parent_phase is not None:
-                nbIndents = (
-                    self._active_phases.get(parent_phase, {}).get(
-                        "nbIndents", self.max_indent
-                    )
-                    + 1
+                nb_indents = (
+                        self._active_phases.get(parent_phase, {}).get(
+                            "nbIndents", self.max_indent
+                        )
+                        + 1
                 )
-                if nbIndents >= self.max_indent:
+                if nb_indents >= self.max_indent:
                     return
             else:
-                nbIndents = 0
+                nb_indents = 0
                 leave = True
             self._active_phases[phase_name] = {
                 "tq": tqdm(
-                    total=num_steps, desc=phase_name, leave=leave, position=nbIndents
+                    total=num_steps, desc=phase_name, leave=leave, position=nb_indents
                 ),
-                "nbIndents": nbIndents,
+                "nbIndents": nb_indents,
                 "parent_phase": parent_phase,
             }
         except KeyboardInterrupt:
@@ -326,8 +332,8 @@ def phase_finish(self, phase_name):
                     "parent_phase", None
                 )
                 if (
-                    self._active_phases[phase_name]["parent_phase"]
-                    in self._active_phases.keys()
+                        self._active_phases[phase_name]["parent_phase"]
+                        in self._active_phases.keys()
                 ):
                     self._active_phases[
                         self._active_phases[phase_name]["parent_phase"]
diff --git a/tensorrt_nodes.py b/tensorrt_nodes.py
index 8b88d5e..b41be31 100644
--- a/tensorrt_nodes.py
+++ b/tensorrt_nodes.py
@@ -38,7 +38,8 @@ def INPUT_TYPES(s):
     FUNCTION = "load_unet"
     CATEGORY = "TensorRT"

-    def load_unet(self, unet_name, model_type):
+    @staticmethod
+    def load_unet(unet_name, model_type):
         unet_path = folder_paths.get_full_path("tensorrt", unet_name)
         model = TRTDiffusionBackbone.load_trt_model(unet_path, model_type)
         return (
@@ -50,7 +51,7 @@ def load_unet(self, unet_name, model_type):
         )


-class TRT_MODEL_CONVERSION_BASE:
+class TRTBuildBase:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()
         self.temp_dir = folder_paths.get_temp_directory()
@@ -72,24 +73,24 @@ def INPUT_TYPES(s):
         raise NotImplementedError

     def _convert(
-        self,
-        model,
-        filename_prefix,
-        batch_size_min,
-        batch_size_opt,
-        batch_size_max,
-        height_min,
-        height_opt,
-        height_max,
-        width_min,
-        width_opt,
-        width_max,
-        context_min,
-        context_opt,
-        context_max,
-        num_video_frames,
-        is_static: bool,
-        output_onnx: Optional[str] = None,
+            self,
+            model,
+            filename_prefix,
+            batch_size_min,
+            batch_size_opt,
+            batch_size_max,
+            height_min,
+            height_opt,
+            height_max,
+            width_min,
+            width_opt,
+            width_max,
+            context_min,
+            context_opt,
+            context_max,
+            num_video_frames,
+            is_static: bool,
+            output_onnx: Optional[str] = None,
     ):
         if output_onnx is None:
             output_onnx = os.path.normpath(
@@ -187,9 +188,9 @@ def _convert(
         return ()


-class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
+class DynamicTRTBuild(TRTBuildBase):
     def __init__(self):
-        super(DYNAMIC_TRT_MODEL_CONVERSION, self).__init__()
+        super(DynamicTRTBuild, self).__init__()

     @classmethod
     def INPUT_TYPES(s):
@@ -321,23 +322,23 @@ def INPUT_TYPES(s):
             }
         }
     def convert(
-        self,
-        model,
-        filename_prefix,
-        batch_size_min,
-        batch_size_opt,
-        batch_size_max,
-        height_min,
-        height_opt,
-        height_max,
-        width_min,
-        width_opt,
-        width_max,
-        context_min,
-        context_opt,
-        context_max,
-        num_video_frames,
-        onnx_model_path,
+            self,
+            model,
+            filename_prefix,
+            batch_size_min,
+            batch_size_opt,
+            batch_size_max,
+            height_min,
+            height_opt,
+            height_max,
+            width_min,
+            width_opt,
+            width_max,
+            context_min,
+            context_opt,
+            context_max,
+            num_video_frames,
+            onnx_model_path,
     ):
         return super()._convert(
             model,
@@ -360,9 +361,9 @@ def convert(
         )


-class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
+class StaticTRTBuild(TRTBuildBase):
     def __init__(self):
-        super(STATIC_TRT_MODEL_CONVERSION, self).__init__()
+        super(StaticTRTBuild, self).__init__()

     @classmethod
     def INPUT_TYPES(s):
@@ -422,15 +423,15 @@ def INPUT_TYPES(s):
             }
         }
     def convert(
-        self,
-        model,
-        filename_prefix,
-        batch_size_opt,
-        height_opt,
-        width_opt,
-        context_opt,
-        num_video_frames,
-        onnx_model_path,
+            self,
+            model,
+            filename_prefix,
+            batch_size_opt,
+            height_opt,
+            width_opt,
+            context_opt,
+            num_video_frames,
+            onnx_model_path,
     ):
         return super()._convert(
             model,
@@ -455,8 +456,8 @@ def convert(

 NODE_CLASS_MAPPINGS = {
     "TensorRTLoader": TensorRTLoader,
-    "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION,
-    "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION,
+    "DYNAMIC_TRT_MODEL_CONVERSION": DynamicTRTBuild,
+    "STATIC_TRT_MODEL_CONVERSION": StaticTRTBuild,
 }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "TensorRTLoader": "TensorRT Loader",