diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index f10f0c4..e669db7 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -17,5 +17,5 @@ jobs:
       - name: Publish Custom Node
         uses: Comfy-Org/publish-node-action@main
         with:
-          ## Add your own personal access token to your Github Repository secrets and reference it here.
+          ## Add your own personal access token to your GitHub Repository secrets and reference it here.
           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
\ No newline at end of file
diff --git a/README.md b/README.md
index 63df71a..69af30e 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
 # TensorRT Node for ComfyUI
 
-This node enables the best performance on NVIDIA RTX™ Graphics Cards
- (GPUs) for Stable Diffusion by leveraging NVIDIA TensorRT.
+This node enables the best performance on NVIDIA RTX™ Graphics Cards (GPUs) for Stable Diffusion by leveraging NVIDIA
+TensorRT.
 
 Supports:
 
@@ -11,8 +11,8 @@ Supports:
 - SDXL
 - SDXL Turbo
 - Stable Video Diffusion
-- Stable Video Diffusion-XT  
-- AuraFlow
+- Stable Video Diffusion-XT
+- AuraFlow
 - Flux
 
 Requirements:
@@ -31,7 +30,8 @@ Requirements:
 
 The recommended way to install these nodes is to use the [ComfyUI Manager](https://github.com/ltdrdata/ComfyUI-Manager) to easily install them to your ComfyUI instance.
 
-You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the requirements like:
+You can also manually install them by git cloning the repo to your ComfyUI/custom_nodes folder and installing the
+requirements like:
 
 ```
 cd custom_nodes
@@ -62,7 +62,7 @@ You have the option to build either dynamic or static TensorRT engines:
 Note: Most users will prefer dynamic engines, but static engines can
 be useful if you use a specific resolution + batch size combination
 most of the time. Static engines also require less VRAM; the wider the dynamic
-range, the more VRAM will be consumed. 
+range, the more VRAM will be consumed.
 
 ## Instructions
 
@@ -71,19 +71,19 @@ These .json files can be loaded in ComfyUI.
 
 ### Building A TensorRT Engine From a Checkpoint
 
-1. Add a Load Checkpoint Node
-2. Add either a Static Model TensorRT Conversion node or a Dynamic
-   Model TensorRT Conversion node to ComfyUI
-3. ![](readme_images/image3.png)
-4. Connect the Load Checkpoint Model output to the TensorRT Conversion
-   Node Model input.
-5. ![](readme_images/image5.png)
-6. ![](readme_images/image2.png)
-7. To help identify the converted TensorRT model, provide a meaningful
-   filename prefix, add this filename after “tensorrt/”
-8. ![](readme_images/image9.png)
-
-9. Click on Queue Prompt to start building the TensorRT Engines
+1. Add a Load Checkpoint Node
+2. Add either a Static Model TensorRT Conversion node or a Dynamic
+   Model TensorRT Conversion node to ComfyUI
+3. ![](readme_images/image3.png)
+4. Connect the Load Checkpoint Model output to the TensorRT Conversion
+   Node Model input.
+5. ![](readme_images/image5.png)
+6. ![](readme_images/image2.png)
+7. To help identify the converted TensorRT model, provide a meaningful
+   filename prefix, add this filename after “tensorrt/”
+8. ![](readme_images/image9.png)
+
+9. Click on Queue Prompt to start building the TensorRT Engines
 10. ![](readme_images/image7.png)
 
 ![](readme_images/image11.png)
@@ -96,9 +96,9 @@ the console.
 
 ![](readme_images/image4.png)
 
-The first time generating an engine for a checkpoint will take awhile.
+The first time generating an engine for a checkpoint will take a while.
Additional engines generated thereafter for the same checkpoint will be -much faster. Generating engines can take anywhere from 3-10 minutes for +much faster. Generating engines can take anywhere from 3-10 minutes for the image generation models and 10-25 minutes for SVD. SVD-XT is an extremely extensive model - engine build times may take up to an hour. @@ -115,33 +115,33 @@ TensorRT Engines are loaded using the TensorRT Loader node. ComfyUI TensorRT engines are not yet compatible with ControlNets or LoRAs. Compatibility will be enabled in a future update. -1. Add a TensorRT Loader node -2. Note, if a TensorRT Engine has been created during a ComfyUI - session, it will not show up in the TensorRT Loader until the - ComfyUI interface has been refreshed (F5 to refresh browser). -3. ![](readme_images/image6.png) -4. Select a TensorRT Engine from the unet_name dropdown -5. Dynamic Engines will use a filename format of: +1. Add a TensorRT Loader node +2. Note, if a TensorRT Engine has been created during a ComfyUI + session, it will not show up in the TensorRT Loader until the + ComfyUI interface has been refreshed (F5 to refresh browser). +3. ![](readme_images/image6.png) +4. Select a TensorRT Engine from the unet_name dropdown +5. Dynamic Engines will use a filename format of:   -1. dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt -2. dyn=dynamic, b=batch size, h=height, w=width +1. dyn-b-min-max-opt-h-min-max-opt-w-min-max-opt +2. dyn=dynamic, b=batch size, h=height, w=width   -6. Static Engine will use a filename format of: +6. Static Engine will use a filename format of:   -1. stat-b-opt-h-opt-w-opt -2. stat=static, b=batch size, h=height, w=width +1. stat-b-opt-h-opt-w-opt +2. stat=static, b=batch size, h=height, w=width   -7. ![](readme_images/image8.png) -8. The model_type must match the model type of the TensorRT engine. -9. ![](readme_images/image10.png) +7. ![](readme_images/image8.png) +8. The model_type must match the model type of the TensorRT engine. +9. ![](readme_images/image10.png) 10. The CLIP and VAE for the workflow will need to be utilized from the original model checkpoint, the MODEL output from the TensorRT Loader will be connected to the Sampler. 
diff --git a/__init__.py b/__init__.py index 468b3d5..cd30b14 100644 --- a/__init__.py +++ b/__init__.py @@ -1,11 +1,9 @@ -from .tensorrt_convert import DYNAMIC_TRT_MODEL_CONVERSION -from .tensorrt_convert import STATIC_TRT_MODEL_CONVERSION -from .tensorrt_loader import TrTUnet -from .tensorrt_loader import TensorRTLoader +from .onnx_nodes import NODE_CLASS_MAPPING as ONNX_CLASS_MAP +from .onnx_nodes import NODE_DISPLAY_NAME_MAPPINGS as ONNX_NAME_MAP +from .tensorrt_nodes import NODE_CLASS_MAPPINGS as TRT_CLASS_MAP +from .tensorrt_nodes import NODE_DISPLAY_NAME_MAPPINGS as TRT_NAME_MAP -NODE_CLASS_MAPPINGS = { "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION, "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION, "TensorRTLoader": TensorRTLoader } +NODE_CLASS_MAPPINGS = TRT_CLASS_MAP | ONNX_CLASS_MAP +NODE_DISPLAY_NAME_MAPPINGS = TRT_NAME_MAP | ONNX_NAME_MAP - -NODE_DISPLAY_NAME_MAPPINGS = { "DYNAMIC_TRT_MODEL_CONVERSION": "DYNAMIC TRT_MODEL CONVERSION", "STATIC TRT_MODEL CONVERSION": STATIC_TRT_MODEL_CONVERSION, "TensorRTLoader": "TensorRT Loader" } - -__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] \ No newline at end of file +__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..52d25c3 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,9 @@ +from .baseline import TRTModelUtil +from .supported_models import ( + supported_models, + unsupported_models, + detect_version_from_model, + get_helper_from_version, + get_helper_from_model, + get_model_from_version, +) diff --git a/models/auraflow.py b/models/auraflow.py new file mode 100644 index 0000000..9becebc --- /dev/null +++ b/models/auraflow.py @@ -0,0 +1,23 @@ +from .baseline import TRTModelUtil + + +class AuraFlow_TRT(TRTModelUtil): + def __init__( + self, context_dim=2048, input_channels=4, context_len=256, **kwargs + ) -> None: + super().__init__( + context_dim=context_dim, + input_channels=input_channels, + context_len=context_len, + **kwargs, + ) + self.is_conditional = True + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.model_config.unet_config["cond_seq_dim"], + input_channels=model.model.diffusion_model.out_channels, + use_control=False, + **kwargs, + ) diff --git a/models/baseline.py b/models/baseline.py new file mode 100644 index 0000000..dfebe59 --- /dev/null +++ b/models/baseline.py @@ -0,0 +1,105 @@ +import torch + + +class TRTModelUtil: + def __init__( + self, + context_dim: int, + input_channels: int, + context_len: int, + use_control: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.context_dim = context_dim + self.input_channels = input_channels + self.context_len = context_len + self.use_control = use_control + self.is_conditional = False + + self.input_config = { + "x": { + "batch": "{batch_size}", + "input_channels": self.input_channels, + "height": "{height}//8", + "width": "{width}//8", + }, + "timesteps": { + "batch": "{batch_size}", + }, + "context": { + "batch": "{batch_size}", + "context_len": "{context_len}", + "context_dim": self.context_dim, + }, + } + + self.output_config = { + "h": { + "batch": "{batch_size}", + "input_channels": self.input_channels, + "height": "{height}//8", + "width": "{width}//8", + } + } + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "context_len": self.context_dim, + "use_control": self.use_control, + } 
+ + def get_input_names(self) -> list[str]: + return list(self.input_config.keys()) + + def get_output_names(self) -> list[str]: + return list(self.output_config.keys()) + + def get_dtype(self) -> torch.dtype: + return torch.float16 + + def get_input_shapes(self, **kwargs) -> dict: + inputs_shapes = {} + for io_name, io_config in self.input_config.items(): + _inp = self._eval_shape(io_config, **kwargs) + inputs_shapes[io_name] = _inp + + return inputs_shapes + + def get_input_shapes_by_key(self, key: str, **kwargs) -> tuple[int]: + return self._eval_shape(self.input_config[key], **kwargs) + + def get_dynamic_axes(self, config: dict = {}) -> dict: + dynamic_axes = {} + + if config == {}: + config = self.input_config | self.output_config + for k, v in config.items(): + dyn = {i: ax for i, (ax, s) in enumerate(v.items()) if isinstance(s, str)} + dynamic_axes[k] = dyn + + return dynamic_axes + + def _eval_shape(self, inp, **kwargs) -> tuple[int]: + if "context_len" not in kwargs: + kwargs["context_len"] = self.context_len + shape = [] + for _, v in inp.items(): + _s = v + if isinstance(v, str): + _s = int(eval(v.format(**kwargs))) + shape.append(_s) + return tuple(shape) + + def get_control(self, *args, **kwargs) -> dict: + raise NotImplementedError + + @classmethod + def from_model(cls, model, **kwargs): + raise NotImplementedError + + def model_attributes(self, **kwargs) -> dict: + return {} diff --git a/models/flux.py b/models/flux.py new file mode 100644 index 0000000..31d273b --- /dev/null +++ b/models/flux.py @@ -0,0 +1,133 @@ +import torch + +from .baseline import TRTModelUtil + + +class FLuxBase(TRTModelUtil): + def __init__( + self, + context_dim: int, + input_channels: int, + y_dim: int, + hidden_size: int, + double_blocks: int, + single_blocks: int, + *args, + **kwargs, + ) -> None: + super().__init__(context_dim, input_channels, 256, *args, **kwargs) + + self.hidden_size = hidden_size + self.y_dim = y_dim + self.single_blocks = single_blocks + self.double_blocks = double_blocks + + self.extra_input = { + "guidance": {"batch": "{batch_size}"}, + "y": {"batch": "{batch_size}", "y_dim": y_dim}, + } + + self.input_config.update(self.extra_input) + + if self.use_control: + self.control = self.get_control() + self.input_config.update(self.control) + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "y_dim": self.y_dim, + "hidden_size": self.hidden_size, + "double_blocks": self.double_blocks, + "single_blocks": self.single_blocks, + "use_control": self.use_control, + } + + def get_control(self): + control_input = {} + for i in range(self.double_blocks): + control_input[f"input_control_{i}"] = { + "batch": "{batch_size}", + "ids": "({height}*{width}//(8*2)**2)", + "hidden_size": self.hidden_size, + } + for i in range(self.single_blocks): + control_input[f"output_control_{i}"] = { + "batch": "{batch_size}", + "ids": "({height}*{width}//(8*2)**2)", + "hidden_size": self.hidden_size, + } + return control_input + + def get_dtype(self): + return torch.bfloat16 + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.model_config.unet_config["context_in_dim"], + input_channels=model.model.model_config.unet_config["in_channels"], + hidden_size=model.model.model_config.unet_config["hidden_size"], + y_dim=model.model.model_config.unet_config["vec_in_dim"], + double_blocks=model.model.model_config.unet_config["depth"], + single_blocks=model.model.model_config.unet_config["depth_single_blocks"], + 
**kwargs, + ) + + def model_attributes(self, **kwargs) -> dict: + return { + "double_blocks": [None, ] * self.double_blocks, + } + + +class Flux_TRT(FLuxBase): + def __init__( + self, + context_dim=4096, + input_channels=16, + y_dim=768, + hidden_size=3072, + double_blocks=19, + single_blocks=28, + **kwargs, + ): + super().__init__( + context_dim=context_dim, + input_channels=input_channels, + y_dim=y_dim, + hidden_size=hidden_size, + double_blocks=double_blocks, + single_blocks=single_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model, **kwargs): + return super(Flux_TRT, cls).from_model(model, use_control=True) + + +class FluxSchnell_TRT(FLuxBase): + def __init__( + self, + context_dim=4096, + input_channels=16, + y_dim=768, + hidden_size=3072, + double_blocks=19, + single_blocks=28, + **kwargs, + ): + super().__init__( + context_dim=context_dim, + input_channels=input_channels, + y_dim=y_dim, + hidden_size=hidden_size, + double_blocks=double_blocks, + single_blocks=single_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model, **kwargs): + return super(FluxSchnell_TRT, cls).from_model(model, use_control=True) diff --git a/models/sd3.py b/models/sd3.py new file mode 100644 index 0000000..fa2521b --- /dev/null +++ b/models/sd3.py @@ -0,0 +1,68 @@ +import torch + +from .baseline import TRTModelUtil + + +class SD3_TRT(TRTModelUtil): + def __init__( + self, + context_dim: int = 4096, + input_channels: int = 16, + y_dim: int = 2048, + hidden_size: int = 1536, + output_blocks: int = 24, + *args, + **kwargs, + ) -> None: + super().__init__(context_dim, input_channels, 77, *args, **kwargs) + + self.hidden_size = hidden_size + self.y_dim = y_dim + self.is_conditional = True + self.output_blocks = output_blocks # - 1 # self.joint_blocks + + self.extra_input = { + "y": {"batch": "{batch_size}", "y_dim": y_dim}, + } + + self.input_config.update(self.extra_input) + + if self.use_control: + self.control = self.get_control() + self.input_config.update(self.control) + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "hidden_size": self.hidden_size, + "y_dim": self.y_dim, + "output_blocks": self.output_blocks, + "use_control": self.use_control, + } + + def get_control(self): + control_input = {} + for i in range(self.output_blocks): + control_input[f"output_control_{i}"] = { + "batch": "{batch_size}", + "ids": "({height}*{width}//(8*2)**2)", + "hidden_size": self.hidden_size, + } + + return control_input + + def get_dtype(self): + return torch.float16 + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.diffusion_model.context_embedder.in_features, + input_channels=model.model.diffusion_model.in_channels, + hidden_size=model.model.diffusion_model.context_embedder.out_features, + y_dim=model.model.model_config.unet_config.get("adm_in_channels", 0), + output_blocks=model.model.diffusion_model.depth, + use_control=True, + **kwargs, + ) diff --git a/models/sd_unet.py b/models/sd_unet.py new file mode 100644 index 0000000..d7d42a2 --- /dev/null +++ b/models/sd_unet.py @@ -0,0 +1,419 @@ +import torch +from comfy.supported_models import ( + SD15, + SD20, + SD21UnclipL, + SD21UnclipH, + SDXLRefiner, + SDXL, + SSD1B, + Segmind_Vega, + KOALA_700M, + KOALA_1B, + SVD_img2vid, + SD15_instructpix2pix, + SDXL_instructpix2pix, +) + +from .baseline import TRTModelUtil + + +class UNetTRT(TRTModelUtil): + def __init__( + self, + context_dim: int, + input_channels: int, + y_dim: int, + 
hidden_size: int, + channel_mult: tuple[int], + num_res_blocks: tuple[int], + context_len: int = 77, + *args, + **kwargs, + ) -> None: + super().__init__(context_dim, input_channels, context_len, *args, **kwargs) + + self.hidden_size = hidden_size + self.y_dim = y_dim + self.is_conditional = True + + self.channel_mult = channel_mult + self.num_res_blocks = num_res_blocks + self.input_block_chans = self.set_block_chans() + + if self.y_dim: + self.input_config.update({"y": {"batch": "{batch_size}", "y_dim": y_dim}}) + + if self.use_control: + self.control = self.get_control() + self.input_config.update(self.control) + + def set_block_chans(self): + ch = self.hidden_size + ds = 1 + + input_block_chans = [(self.hidden_size, ds)] + for level, mult in enumerate(self.channel_mult): + for nr in range(self.num_res_blocks[level]): + ch = mult * self.hidden_size + input_block_chans.append((ch, ds)) + if level != len(self.channel_mult) - 1: + out_ch = ch + ch = out_ch + ds *= 2 + input_block_chans.append((ch, ds)) + return input_block_chans + + @classmethod + def from_model(cls, model, **kwargs): + return cls( + context_dim=model.model.model_config.unet_config["context_dim"], + input_channels=model.model.diffusion_model.in_channels, + hidden_size=model.model.model_config.unet_config["model_channels"], + y_dim=model.model.model_config.unet_config.get("adm_in_channels", 0), + channel_mult=model.model.diffusion_model.channel_mult, + num_res_blocks=model.model.diffusion_model.num_res_blocks, + **kwargs, + ) + + def to_dict(self): + return { + "context_dim": self.context_dim, + "input_channels": self.input_channels, + "hidden_size": self.hidden_size, + "y_dim": self.y_dim, + "channel_mult": self.channel_mult, + "num_res_blocks": self.num_res_blocks, + "use_control": self.use_control, + } + + def get_control(self): + control_input = {} + + for i, (ch, d) in enumerate(reversed(self.input_block_chans)): + control_input[f"input_control_{i}"] = { + "batch": "{batch_size}", + "chn": ch, + f"height{d}": "{height}//(8*" + str(d) + ")", + f"width{d}": "{width}//(8*" + str(d) + ")", + } + + for i, (ch, d) in enumerate(self.input_block_chans): + control_input[f"output_control_{i}"] = { + "batch": "{batch_size}", + "chn": ch, + f"height{d}": "{height}//(8*" + str(d) + ")", + f"width{d}": "{width}//(8*" + str(d) + ")", + } + + ch, d = self.input_block_chans[-1] + control_input["middle_control_0"] = { + "batch": "{batch_size}", + "chn": ch, + f"height{d}": "{height}//(8*" + str(d) + ")", + f"width{d}": "{width}//(8*" + str(d) + ")", + } + return control_input + + def get_dtype(self): + return torch.float16 + + +class SD15_TRT(UNetTRT): + def __init__( + self, + context_dim=SD15.unet_config["context_dim"], + input_channels=4, + y_dim=0, + hidden_size=SD15.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model, **kwargs): + return super(SD15_TRT, cls).from_model(model, use_control=True) + + +class SD20_TRT(UNetTRT): + def __init__( + self, + context_dim=SD20.unet_config["context_dim"], + input_channels=4, + y_dim=0, + hidden_size=SD20.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + @classmethod + def 
from_model(cls, model, **kwargs): + return super(SD20_TRT, cls).from_model(model, use_control=True) + + +class SD21UnclipL_TRT(UNetTRT): + def __init__( + self, + context_dim=SD21UnclipL.unet_config["context_dim"], + input_channels=4, + y_dim=SD21UnclipL.unet_config["adm_in_channels"], + hidden_size=SD21UnclipL.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SD21UnclipH_TRT(UNetTRT): + def __init__( + self, + context_dim=SD21UnclipH.unet_config["context_dim"], + input_channels=4, + y_dim=SD21UnclipH.unet_config["adm_in_channels"], + hidden_size=SD21UnclipH.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SDXLRefiner_TRT(UNetTRT): + def __init__( + self, + context_dim=SDXLRefiner.unet_config["context_dim"], + input_channels=4, + y_dim=SDXLRefiner.unet_config["adm_in_channels"], + hidden_size=SDXLRefiner.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SDXL_TRT(UNetTRT): + def __init__( + self, + context_dim=SDXL.unet_config["context_dim"], + input_channels=4, + y_dim=SDXL.unet_config["adm_in_channels"], + hidden_size=SDXL.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + @classmethod + def from_model(cls, model, **kwargs): + return super(SDXL_TRT, cls).from_model(model, use_control=True) + + +class SSD1B_TRT(UNetTRT): + def __init__( + self, + context_dim=SSD1B.unet_config["context_dim"], + input_channels=4, + y_dim=SSD1B.unet_config["adm_in_channels"], + hidden_size=SSD1B.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class Segmind_Vega_TRT(UNetTRT): + def __init__( + self, + context_dim=Segmind_Vega.unet_config["context_dim"], + input_channels=4, + y_dim=Segmind_Vega.unet_config["adm_in_channels"], + hidden_size=Segmind_Vega.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class KOALA_700M_TRT(UNetTRT): + def __init__( + self, + context_dim=KOALA_700M.unet_config["context_dim"], + input_channels=4, + y_dim=KOALA_700M.unet_config["adm_in_channels"], + hidden_size=KOALA_700M.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), # TODO + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class KOALA_1B_TRT(UNetTRT): + def __init__( + self, + context_dim=KOALA_1B.unet_config["context_dim"], + input_channels=4, + y_dim=KOALA_1B.unet_config["adm_in_channels"], + hidden_size=KOALA_1B.unet_config["model_channels"], + channel_mult=(1, 2, 4), 
+ num_res_blocks=(2, 2, 2), # TODO + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SVD_img2vid_TRT(UNetTRT): + def __init__( + self, + context_dim=SVD_img2vid.unet_config["context_dim"], + input_channels=8, + y_dim=SVD_img2vid.unet_config["adm_in_channels"], + hidden_size=SVD_img2vid.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + self.input_config["context"]["context_len"] = self.context_len + + +class SD15_instructpix2pix_TRT(UNetTRT): + def __init__( + self, + context_dim=SD15_instructpix2pix.unet_config["context_dim"], + input_channels=8, + y_dim=0, + hidden_size=SD15_instructpix2pix.unet_config["model_channels"], + channel_mult=(1, 2, 4, 4), + num_res_blocks=(2, 2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) + + +class SDXL_instructpix2pix_TRT(UNetTRT): + def __init__( + self, + context_dim=SDXL_instructpix2pix.unet_config["context_dim"], + input_channels=8, + y_dim=SDXL_instructpix2pix.unet_config["adm_in_channels"], + hidden_size=SDXL_instructpix2pix.unet_config["model_channels"], + channel_mult=(1, 2, 4), + num_res_blocks=(2, 2, 2), + **kwargs, + ): + super().__init__( + context_dim, + input_channels, + y_dim, + hidden_size, + channel_mult, + num_res_blocks, + **kwargs, + ) diff --git a/models/supported_models.py b/models/supported_models.py new file mode 100644 index 0000000..bfd6c19 --- /dev/null +++ b/models/supported_models.py @@ -0,0 +1,83 @@ +import comfy.supported_models + +from .auraflow import AuraFlow_TRT +from .baseline import TRTModelUtil +from .flux import Flux_TRT, FluxSchnell_TRT +from .sd3 import SD3_TRT +from .sd_unet import ( + SD15_TRT, + SD20_TRT, + SD21UnclipH_TRT, + SD21UnclipL_TRT, + SDXL_instructpix2pix_TRT, + SDXLRefiner_TRT, + SDXL_TRT, + SSD1B_TRT, + KOALA_700M_TRT, + KOALA_1B_TRT, + Segmind_Vega_TRT, + SVD_img2vid_TRT, +) + +supported_models = { + "SD15": SD15_TRT, + "SD20": SD20_TRT, + "SD21UnclipL": SD21UnclipL_TRT, + "SD21UnclipH": SD21UnclipH_TRT, + "SDXL_instructpix2pix": SDXL_instructpix2pix_TRT, + "SDXLRefiner": SDXLRefiner_TRT, + "SDXL": SDXL_TRT, + "SSD1B": SSD1B_TRT, + "KOALA_700M": KOALA_700M_TRT, + "KOALA_1B": KOALA_1B_TRT, + "Segmind_Vega": Segmind_Vega_TRT, + "SVD_img2vid": SVD_img2vid_TRT, + "SD3": SD3_TRT, + "AuraFlow": AuraFlow_TRT, + "Flux": Flux_TRT, + "FluxSchnell": FluxSchnell_TRT, +} + +unsupported_models = [ + "SV3D_u", + "SV3D_p", + "Stable_Zero123", + "SD_X4Upscaler", + "Stable_Cascade_C", + "Stable_Cascade_B", + "StableAudio", + "HunyuanDiT", + "HunyuanDiT1", +] + + +def detect_version_from_model(model): + return model.model.model_config.__class__.__name__ + + +def get_helper_from_version(model_version: str, config: dict = {}) -> TRTModelUtil: + model_helper = supported_models.get(model_version, None) + if model_helper is None: + raise NotImplementedError("{} is not supported.".format(model_version)) + return model_helper(**config) + + +def get_helper_from_model(model) -> TRTModelUtil: + model_version = detect_version_from_model(model) + helper_cls = supported_models.get(model_version, TRTModelUtil) + return helper_cls.from_model(model) + + +def get_model_from_version(model_version: str, config: dict = {}): + conf = 
getattr(comfy.supported_models, model_version) + helper = get_helper_from_version(model_version, config) + conf.unet_config["disable_unet_model_creation"] = True + conf.unet_config["in_channels"] = helper.input_channels + conf = conf(conf.unet_config) + if model_version in ("SD20",): + model = comfy.model_base.BaseModel( + conf, model_type=comfy.model_base.ModelType.V_PREDICTION + ) + else: + model = conf.get_model({}) + return model, helper diff --git a/onnx_nodes.py b/onnx_nodes.py new file mode 100644 index 0000000..0414f1e --- /dev/null +++ b/onnx_nodes.py @@ -0,0 +1,77 @@ +import os + +import comfy +import folder_paths + +from .onnx_utils.export import export_onnx + + +class ONNXExport: + def __init__(self) -> None: + pass + + RETURN_TYPES = () + FUNCTION = "export" + OUTPUT_NODE = True + CATEGORY = "TensorRT" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "output_folder": ( + "STRING", + {"default": os.path.join(folder_paths.models_dir, "onnx")}, + ), + }, + "optional": {"filename": ("STRING", {"default": "model.onnx"})}, + } + + def export(self, model, output_folder, filename): + comfy.model_management.unload_all_models() + comfy.model_management.load_models_gpu( + [model], force_patch_weights=True, force_full_load=True + ) + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + path = os.path.join(output_folder, filename) + export_onnx(model, path) + print(f"INFO: Exported Model to: {path}") + return () + + +class ONNXModelSelector: + @classmethod + def INPUT_TYPES(s): + onnx_path = os.path.join(folder_paths.models_dir, "onnx") + if not os.path.exists(onnx_path): + os.makedirs(onnx_path) + onnx_models = [f for f in os.listdir(onnx_path) if f.endswith(".onnx")] + return { + "required": { + "model_name": (onnx_models,), + }, + } + + RETURN_TYPES = ("STRING", "STRING") + RETURN_NAMES = ("model_path", "model_name") + FUNCTION = "select_onnx_model" + CATEGORY = "TensorRT" + + def select_onnx_model(self, model_name): + onnx_path = os.path.join(folder_paths.models_dir, "onnx") + model_path = os.path.join(onnx_path, model_name) + return (model_path, model_name) + + +NODE_CLASS_MAPPING = { + "ONNX_EXPORT": ONNXExport, + "ONNXModelSelector": ONNXModelSelector, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ONNX_EXPORT": "ONNX Export", + "ONNXModelSelector": "Select ONNX Model", +} diff --git a/onnx_utils/export.py b/onnx_utils/export.py new file mode 100644 index 0000000..74882f8 --- /dev/null +++ b/onnx_utils/export.py @@ -0,0 +1,236 @@ +import json +import os +import time +from typing import List + +import comfy +import folder_paths +import numpy as np +import onnx +import torch +from onnx import numpy_helper +from onnx.external_data_helper import _get_all_tensors, ExternalDataInfo + +from ..models import detect_version_from_model, get_helper_from_model + + +def _get_onnx_external_data_tensors(model: onnx.ModelProto) -> List[str]: + """ + Gets the paths of the external data tensors in the model. + Note: make sure you load the model with load_external_data=False. 
+ """ + model_tensors = _get_all_tensors(model) + model_tensors_ext = [ + ExternalDataInfo(tensor).location + for tensor in model_tensors + if tensor.HasField("data_location") + and tensor.data_location == onnx.TensorProto.EXTERNAL + ] + return model_tensors_ext + + +def get_sample_input(input_shapes: dict, dtype: torch.dtype, device: torch.device): + inputs = [] + for k, shape in input_shapes.items(): + inputs.append( + torch.zeros( + shape, + device=device, + dtype=dtype, + ) + ) + + return tuple(inputs) + + +def get_backbone(model, model_version, input_names, num_video_frames, use_control): + unet = model.model.diffusion_model + transformer_options = model.model_options["transformer_options"].copy() + + if model_version == "SVD_img2vid": + + class UNET(torch.nn.Module): + def forward(self, x, timesteps, context, y): + return self.unet( + x, + timesteps, + context, + y, + num_video_frames=self.num_video_frames, + transformer_options=self.transformer_options, + ) + + svd_unet = UNET() + svd_unet.num_video_frames = num_video_frames + svd_unet.unet = unet + svd_unet.transformer_options = transformer_options + unet = svd_unet + else: + + class UNET(torch.nn.Module): + def forward(self, x, timesteps, context, *args): + extras = input_names[3:] + control = {"input": [], "output": [], "middle": []} + extra_args = {} + for i in range(len(extras)): + if "control" in extras[i]: + if "input" in extras[i]: + control["input"].append(args[i]) + elif "output" in extras[i]: + control["output"].append(args[i]) + elif "middle" in extras[i]: + control["middle"].append(args[i]) + else: + extra_args[extras[i]] = args[i] + if use_control: + extra_args["control"] = control + return self.unet( + x, + timesteps, + context, + transformer_options=self.transformer_options, + **extra_args, + ) + + _unet = UNET() + _unet.unet = unet + _unet.transformer_options = transformer_options + unet = _unet + + return unet + + +# Helper utility for weights map +def export_weights_map(state_dict, onnx_opt_path: str, weights_map_path: str): + onnx_opt_dir = onnx_opt_path + onnx_opt_model = onnx.load(onnx_opt_path) + + # Create initializer data hashes + def init_hash_map(onnx_opt_model): + initializer_hash_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = ( + initializer_hash, + initializer_data.shape, + ) + return initializer_hash_mapping + + initializer_hash_mapping = init_hash_map(onnx_opt_model) + + weights_name_mapping = {} + weights_shape_mapping = {} + # set to keep track of initializers already added to the name_mapping dict + initializers_mapped = set() + for wt_name, wt in state_dict.items(): + # get weight hash + wt = wt.cpu().detach().numpy().astype(np.float16) + wt_hash = hash(wt.data.tobytes()) + wt_t_hash = hash(np.transpose(wt).data.tobytes()) + + for initializer_name, ( + initializer_hash, + initializer_shape, + ) in initializer_hash_mapping.items(): + # Due to constant folding, some weights are transposed during export + # To account for the transpose op, we compare the initializer hash to the + # hash for the weight and its transpose + if wt_hash == initializer_hash or wt_t_hash == initializer_hash: + # The assert below ensures there is a 1:1 mapping between + # PyTorch and ONNX weight names. 
It can be removed in cases where 1:many + # mapping is found and name_mapping[wt_name] = list() + assert initializer_name not in initializers_mapped + weights_name_mapping[wt_name] = initializer_name + initializers_mapped.add(initializer_name) + is_transpose = False if wt_hash == initializer_hash else True + weights_shape_mapping[wt_name] = ( + initializer_shape, + is_transpose, + ) + + # Sanity check: Were any weights not matched + if wt_name not in weights_name_mapping: + print(f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer") + print( + f"[I] UNet: {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers" + ) + + assert weights_name_mapping.keys() == weights_shape_mapping.keys() + with open(weights_map_path, "w") as fp: + json.dump([weights_name_mapping, weights_shape_mapping], fp) + + +def export_onnx( + model, + path, + batch_size: int = 1, + height: int = 512, + width: int = 512, + num_video_frames: int = 14, + context_multiplier: int = 1, +): + model_version = detect_version_from_model(model) + model_helper = get_helper_from_model(model) + + dtype = model_helper.get_dtype() + device = comfy.model_management.get_torch_device() + input_names = model_helper.get_input_names() + output_names = model_helper.get_output_names() + dynamic_axes = model_helper.get_dynamic_axes() + + if model_version == "SVD_img2vid": + batch_size *= num_video_frames + if model_helper.is_conditional: + batch_size *= 2 + input_shapes = model_helper.get_input_shapes( + batch_size=batch_size, + height=height, + width=width, + context_multiplier=context_multiplier, + ) + inputs = get_sample_input(input_shapes, dtype, device) + backbone = get_backbone( + model, model_version, input_names, num_video_frames, model_helper.use_control + ) + + _, name = os.path.split(path) + temp_path = os.path.join( + folder_paths.get_temp_directory(), "{}".format(time.time()) + ) + onnx_temp = os.path.normpath(os.path.join(temp_path, name)) + + if not os.path.exists(temp_path): + os.makedirs(temp_path) + + torch.onnx.export( + backbone, + inputs, + onnx_temp, + verbose=False, + input_names=input_names, + output_names=output_names, + opset_version=17, + dynamic_axes=dynamic_axes, + ) + + comfy.model_management.unload_all_models() + comfy.model_management.soft_empty_cache() + + onnx_model = onnx.load(onnx_temp, load_external_data=True) + tensors_paths = _get_onnx_external_data_tensors(onnx_model) + + if tensors_paths: + for tensor in tensors_paths: + os.remove(os.path.join(onnx_temp, tensor)) + + onnx.save( + onnx_model, + path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=name + "_data", + size_threshold=1024, + ) diff --git a/pyproject.toml b/pyproject.toml index e7adb94..f2ff4f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,8 +4,8 @@ description = "TensorRT Node for ComfyUI\nThis node enables the best performance version = "0.1.8" license = "LICENSE" dependencies = [ - "tensorrt>=10.0.1", - "onnx" + "tensorrt>=10.6", + "onnx" ] [project.urls] diff --git a/requirements.txt b/requirements.txt index cef004a..564aae7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,6 @@ -tensorrt>=10.0.1 +tensorrt>=10.4.0 --extra-index-url https://pypi.nvidia.com onnx!=1.16.2 +onnx-graphsurgeon +onnxmltools +onnxconverter-common +pulp diff --git a/tensorrt_convert.py b/tensorrt_convert.py deleted file mode 100644 index 6ca4bff..0000000 --- a/tensorrt_convert.py +++ /dev/null @@ -1,650 +0,0 @@ -import torch -import sys -import os -import time -import 
comfy.model_management - -import tensorrt as trt -import folder_paths -from tqdm import tqdm - -# TODO: -# Make it more generic: less model specific code - -# add output directory to tensorrt search path -if "tensorrt" in folder_paths.folder_names_and_paths: - folder_paths.folder_names_and_paths["tensorrt"][0].append( - os.path.join(folder_paths.get_output_directory(), "tensorrt") - ) - folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") -else: - folder_paths.folder_names_and_paths["tensorrt"] = ( - [os.path.join(folder_paths.get_output_directory(), "tensorrt")], - {".engine"}, - ) - -class TQDMProgressMonitor(trt.IProgressMonitor): - def __init__(self): - trt.IProgressMonitor.__init__(self) - self._active_phases = {} - self._step_result = True - self.max_indent = 5 - - def phase_start(self, phase_name, parent_phase, num_steps): - leave = False - try: - if parent_phase is not None: - nbIndents = ( - self._active_phases.get(parent_phase, {}).get( - "nbIndents", self.max_indent - ) - + 1 - ) - if nbIndents >= self.max_indent: - return - else: - nbIndents = 0 - leave = True - self._active_phases[phase_name] = { - "tq": tqdm( - total=num_steps, desc=phase_name, leave=leave, position=nbIndents - ), - "nbIndents": nbIndents, - "parent_phase": parent_phase, - } - except KeyboardInterrupt: - # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete. - _step_result = False - - def phase_finish(self, phase_name): - try: - if phase_name in self._active_phases.keys(): - self._active_phases[phase_name]["tq"].update( - self._active_phases[phase_name]["tq"].total - - self._active_phases[phase_name]["tq"].n - ) - - parent_phase = self._active_phases[phase_name].get("parent_phase", None) - while parent_phase is not None: - self._active_phases[parent_phase]["tq"].refresh() - parent_phase = self._active_phases[parent_phase].get( - "parent_phase", None - ) - if ( - self._active_phases[phase_name]["parent_phase"] - in self._active_phases.keys() - ): - self._active_phases[ - self._active_phases[phase_name]["parent_phase"] - ]["tq"].refresh() - del self._active_phases[phase_name] - pass - except KeyboardInterrupt: - _step_result = False - - def step_complete(self, phase_name, step): - try: - if phase_name in self._active_phases.keys(): - self._active_phases[phase_name]["tq"].update( - step - self._active_phases[phase_name]["tq"].n - ) - return self._step_result - except KeyboardInterrupt: - # There is no need to propagate this exception to TensorRT. We can simply cancel the build. 
- return False - - -class TRT_MODEL_CONVERSION_BASE: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.temp_dir = folder_paths.get_temp_directory() - self.timing_cache_path = os.path.normpath( - os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt")) - ) - - RETURN_TYPES = () - FUNCTION = "convert" - OUTPUT_NODE = True - CATEGORY = "TensorRT" - - @classmethod - def INPUT_TYPES(s): - raise NotImplementedError - - # Sets up the builder to use the timing cache file, and creates it if it does not already exist - def _setup_timing_cache(self, config: trt.IBuilderConfig): - buffer = b"" - if os.path.exists(self.timing_cache_path): - with open(self.timing_cache_path, mode="rb") as timing_cache_file: - buffer = timing_cache_file.read() - print("Read {} bytes from timing cache.".format(len(buffer))) - else: - print("No timing cache found; Initializing a new one.") - timing_cache: trt.ITimingCache = config.create_timing_cache(buffer) - config.set_timing_cache(timing_cache, ignore_mismatch=True) - - # Saves the config's timing cache to file - def _save_timing_cache(self, config: trt.IBuilderConfig): - timing_cache: trt.ITimingCache = config.get_timing_cache() - with open(self.timing_cache_path, "wb") as timing_cache_file: - timing_cache_file.write(memoryview(timing_cache.serialize())) - - def _convert( - self, - model, - filename_prefix, - batch_size_min, - batch_size_opt, - batch_size_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - context_min, - context_opt, - context_max, - num_video_frames, - is_static: bool, - ): - output_onnx = os.path.normpath( - os.path.join( - os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx" - ) - ) - - comfy.model_management.unload_all_models() - comfy.model_management.load_models_gpu([model], force_patch_weights=True, force_full_load=True) - unet = model.model.diffusion_model - - context_dim = model.model.model_config.unet_config.get("context_dim", None) - context_len = 77 - context_len_min = context_len - y_dim = model.model.adm_channels - extra_input = {} - dtype = torch.float16 - - if isinstance(model.model, comfy.model_base.SD3): #SD3 - context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None) - if context_embedder_config is not None: - context_dim = context_embedder_config.get("params", {}).get("in_features", None) - context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77 - elif isinstance(model.model, comfy.model_base.AuraFlow): - context_dim = 2048 - context_len_min = 256 - context_len = 256 - elif isinstance(model.model, comfy.model_base.Flux): - context_dim = model.model.model_config.unet_config.get("context_in_dim", None) - context_len_min = 256 - context_len = 256 - y_dim = model.model.model_config.unet_config.get("vec_in_dim", None) - extra_input = {"guidance": ()} - dtype = torch.bfloat16 - - if context_dim is not None: - input_names = ["x", "timesteps", "context"] - output_names = ["h"] - - dynamic_axes = { - "x": {0: "batch", 2: "height", 3: "width"}, - "timesteps": {0: "batch"}, - "context": {0: "batch", 1: "num_embeds"}, - } - - transformer_options = model.model_options['transformer_options'].copy() - if model.model.model_config.unet_config.get( - "use_temporal_resblock", False - ): # SVD - batch_size_min = num_video_frames * batch_size_min - batch_size_opt = num_video_frames * batch_size_opt - batch_size_max = 
num_video_frames * batch_size_max - - class UNET(torch.nn.Module): - def forward(self, x, timesteps, context, y): - return self.unet( - x, - timesteps, - context, - y, - num_video_frames=self.num_video_frames, - transformer_options=self.transformer_options, - ) - - svd_unet = UNET() - svd_unet.num_video_frames = num_video_frames - svd_unet.unet = unet - svd_unet.transformer_options = transformer_options - unet = svd_unet - context_len_min = context_len = 1 - else: - class UNET(torch.nn.Module): - def forward(self, x, timesteps, context, *args): - extras = input_names[3:] - extra_args = {} - for i in range(len(extras)): - extra_args[extras[i]] = args[i] - return self.unet(x, timesteps, context, transformer_options=self.transformer_options, **extra_args) - - _unet = UNET() - _unet.unet = unet - _unet.transformer_options = transformer_options - unet = _unet - - input_channels = model.model.model_config.unet_config.get("in_channels", 4) - - inputs_shapes_min = ( - (batch_size_min, input_channels, height_min // 8, width_min // 8), - (batch_size_min,), - (batch_size_min, context_len_min * context_min, context_dim), - ) - inputs_shapes_opt = ( - (batch_size_opt, input_channels, height_opt // 8, width_opt // 8), - (batch_size_opt,), - (batch_size_opt, context_len * context_opt, context_dim), - ) - inputs_shapes_max = ( - (batch_size_max, input_channels, height_max // 8, width_max // 8), - (batch_size_max,), - (batch_size_max, context_len * context_max, context_dim), - ) - - if y_dim > 0: - input_names.append("y") - dynamic_axes["y"] = {0: "batch"} - inputs_shapes_min += ((batch_size_min, y_dim),) - inputs_shapes_opt += ((batch_size_opt, y_dim),) - inputs_shapes_max += ((batch_size_max, y_dim),) - - for k in extra_input: - input_names.append(k) - dynamic_axes[k] = {0: "batch"} - inputs_shapes_min += ((batch_size_min,) + extra_input[k],) - inputs_shapes_opt += ((batch_size_opt,) + extra_input[k],) - inputs_shapes_max += ((batch_size_max,) + extra_input[k],) - - - inputs = () - for shape in inputs_shapes_opt: - inputs += ( - torch.zeros( - shape, - device=comfy.model_management.get_torch_device(), - dtype=dtype, - ), - ) - - else: - print("ERROR: model not supported.") - return () - - os.makedirs(os.path.dirname(output_onnx), exist_ok=True) - torch.onnx.export( - unet, - inputs, - output_onnx, - verbose=False, - input_names=input_names, - output_names=output_names, - opset_version=17, - dynamic_axes=dynamic_axes, - ) - - comfy.model_management.unload_all_models() - comfy.model_management.soft_empty_cache() - - # TRT conversion starts here - logger = trt.Logger(trt.Logger.INFO) - builder = trt.Builder(logger) - - network = builder.create_network( - 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - ) - parser = trt.OnnxParser(network, logger) - success = parser.parse_from_file(output_onnx) - for idx in range(parser.num_errors): - print(parser.get_error(idx)) - - if not success: - print("ONNX load ERROR") - return () - - config = builder.create_builder_config() - profile = builder.create_optimization_profile() - self._setup_timing_cache(config) - config.progress_monitor = TQDMProgressMonitor() - - prefix_encode = "" - for k in range(len(input_names)): - min_shape = inputs_shapes_min[k] - opt_shape = inputs_shapes_opt[k] - max_shape = inputs_shapes_max[k] - profile.set_shape(input_names[k], min_shape, opt_shape, max_shape) - - # Encode shapes to filename - encode = lambda a: ".".join(map(lambda x: str(x), a)) - prefix_encode += "{}#{}#{}#{};".format( - input_names[k], encode(min_shape), 
encode(opt_shape), encode(max_shape) - ) - - if dtype == torch.float16: - config.set_flag(trt.BuilderFlag.FP16) - if dtype == torch.bfloat16: - config.set_flag(trt.BuilderFlag.BF16) - - config.add_optimization_profile(profile) - - if is_static: - filename_prefix = "{}_${}".format( - filename_prefix, - "-".join( - ( - "stat", - "b", - str(batch_size_opt), - "h", - str(height_opt), - "w", - str(width_opt), - ) - ), - ) - else: - filename_prefix = "{}_${}".format( - filename_prefix, - "-".join( - ( - "dyn", - "b", - str(batch_size_min), - str(batch_size_max), - str(batch_size_opt), - "h", - str(height_min), - str(height_max), - str(height_opt), - "w", - str(width_min), - str(width_max), - str(width_opt), - ) - ), - ) - - serialized_engine = builder.build_serialized_network(network, config) - - full_output_folder, filename, counter, subfolder, filename_prefix = ( - folder_paths.get_save_image_path(filename_prefix, self.output_dir) - ) - output_trt_engine = os.path.join( - full_output_folder, f"{filename}_{counter:05}_.engine" - ) - - with open(output_trt_engine, "wb") as f: - f.write(serialized_engine) - - self._save_timing_cache(config) - - return () - - -class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE): - def __init__(self): - super(DYNAMIC_TRT_MODEL_CONVERSION, self).__init__() - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}), - "batch_size_min": ( - "INT", - { - "default": 1, - "min": 1, - "max": 100, - "step": 1, - }, - ), - "batch_size_opt": ( - "INT", - { - "default": 1, - "min": 1, - "max": 100, - "step": 1, - }, - ), - "batch_size_max": ( - "INT", - { - "default": 1, - "min": 1, - "max": 100, - "step": 1, - }, - ), - "height_min": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "height_opt": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "height_max": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "width_min": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "width_opt": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "width_max": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "context_min": ( - "INT", - { - "default": 1, - "min": 1, - "max": 128, - "step": 1, - }, - ), - "context_opt": ( - "INT", - { - "default": 1, - "min": 1, - "max": 128, - "step": 1, - }, - ), - "context_max": ( - "INT", - { - "default": 1, - "min": 1, - "max": 128, - "step": 1, - }, - ), - "num_video_frames": ( - "INT", - { - "default": 14, - "min": 0, - "max": 1000, - "step": 1, - }, - ), - }, - } - - def convert( - self, - model, - filename_prefix, - batch_size_min, - batch_size_opt, - batch_size_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - context_min, - context_opt, - context_max, - num_video_frames, - ): - return super()._convert( - model, - filename_prefix, - batch_size_min, - batch_size_opt, - batch_size_max, - height_min, - height_opt, - height_max, - width_min, - width_opt, - width_max, - context_min, - context_opt, - context_max, - num_video_frames, - is_static=False, - ) - - -class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE): - def __init__(self): - super(STATIC_TRT_MODEL_CONVERSION, self).__init__() - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - 
"filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}), - "batch_size_opt": ( - "INT", - { - "default": 1, - "min": 1, - "max": 100, - "step": 1, - }, - ), - "height_opt": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "width_opt": ( - "INT", - { - "default": 512, - "min": 256, - "max": 4096, - "step": 64, - }, - ), - "context_opt": ( - "INT", - { - "default": 1, - "min": 1, - "max": 128, - "step": 1, - }, - ), - "num_video_frames": ( - "INT", - { - "default": 14, - "min": 0, - "max": 1000, - "step": 1, - }, - ), - }, - } - - def convert( - self, - model, - filename_prefix, - batch_size_opt, - height_opt, - width_opt, - context_opt, - num_video_frames, - ): - return super()._convert( - model, - filename_prefix, - batch_size_opt, - batch_size_opt, - batch_size_opt, - height_opt, - height_opt, - height_opt, - width_opt, - width_opt, - width_opt, - context_opt, - context_opt, - context_opt, - num_video_frames, - is_static=True, - ) - - -NODE_CLASS_MAPPINGS = { - "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION, - "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION, -} diff --git a/tensorrt_diffusion_model.py b/tensorrt_diffusion_model.py new file mode 100644 index 0000000..b3c4d4c --- /dev/null +++ b/tensorrt_diffusion_model.py @@ -0,0 +1,355 @@ +import os +from math import prod +from typing import Optional + +import comfy.model_management +import tensorrt as trt +import torch +from tqdm import tqdm + +from .models import get_model_from_version, TRTModelUtil + +trt.init_libnvinfer_plugins(None, "") +logger = trt.Logger(trt.Logger.INFO) +runtime = trt.Runtime(logger) + + +def trt_datatype_to_torch(datatype): + if datatype == trt.float16: + return torch.float16 + elif datatype == trt.float32: + return torch.float32 + elif datatype == trt.int32: + return torch.int32 + elif datatype == trt.bfloat16: + return torch.bfloat16 + + +class TRTModel(torch.nn.Module): + def __init__(self, model_helper: TRTModelUtil, *args, **kwargs) -> None: + super(TRTModel, self).__init__() + + self.context = None + self.engine = None + + self.model = model_helper + for k, v in self.model.model_attributes().items(): + setattr(self, k, v) + + self.dtype = self.model.get_dtype() + self.device = comfy.model_management.get_torch_device() + + self.input_names = self.model.get_input_names() + self.output_names = self.model.get_output_names() + + self.current_shape: tuple[int] = (0,) + self.output_shapes: dict[str, tuple[int]] = {} + self.curr_split_batch: int = 0 + + self.zero_pool = None + self.extra_inputs: dict[str, torch.Tensor] = {} + + # Sets up the builder to use the timing cache file, and creates it if it does not already exist + @staticmethod + def _setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: str): + buffer = b"" + if os.path.exists(timing_cache_path): + with open(timing_cache_path, mode="rb") as timing_cache_file: + buffer = timing_cache_file.read() + print("Read {} bytes from timing cache.".format(len(buffer))) + else: + print("No timing cache found; Initializing a new one.") + timing_cache: trt.ITimingCache = config.create_timing_cache(buffer) + config.set_timing_cache(timing_cache, ignore_mismatch=True) + + # Saves the config's timing cache to file + @staticmethod + def _save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: str): + timing_cache: trt.ITimingCache = config.get_timing_cache() + with open(timing_cache_path, "wb") as timing_cache_file: + 
timing_cache_file.write(memoryview(timing_cache.serialize())) + + def _create_profile(self, builder, min_config, opt_config, max_config): + profile = builder.create_optimization_profile() + + min_config = opt_config if min_config is None else min_config + max_config = opt_config if min_config is None else max_config + + min_shapes = self.model.get_input_shapes(**min_config) + opt_shapes = self.model.get_input_shapes(**opt_config) + max_shapes = self.model.get_input_shapes(**max_config) + for input_name in opt_shapes.keys(): + profile.set_shape( + input_name, + min_shapes[input_name], + opt_shapes[input_name], + max_shapes[input_name], + ) + + return profile + + def build( + self, + onnx_path: str, + engine_path: str, + timing_cache_path: str, + opt_config: dict, + min_config: Optional[dict] = None, + max_config: Optional[dict] = None, + ) -> bool: + comfy.model_management.unload_all_models() + comfy.model_management.soft_empty_cache() + + builder = trt.Builder(logger) + + network = builder.create_network( + 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) + ) + parser = trt.OnnxParser(network, logger) + success = parser.parse_from_file(onnx_path) + for idx in range(parser.num_errors): + print(parser.get_error(idx)) + + if not success: + print("ONNX load ERROR") + return False + + config = builder.create_builder_config() + # config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED + self._setup_timing_cache(config, timing_cache_path) + config.progress_monitor = TQDMProgressMonitor() + profile = self._create_profile(builder, min_config, opt_config, max_config) + config.add_optimization_profile(profile) + + config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) # STRIP_PLAN + + self.engine = builder.build_serialized_network(network, config) + if self.engine is None: + raise Exception("Failed to build Engine") + + engine_view = memoryview(self.engine) + model = { + "engine": torch.frombuffer(engine_view, dtype=torch.uint8), + "config": self.model.to_dict(), + } + torch.save(model, engine_path) + + return True + + @torch.cuda.nvtx.range("set_bindings_shape") + def set_bindings_shape(self, inputs): + for k in inputs: + shape = inputs[k].shape + shape = [shape[0] // self.curr_split_batch] + list(shape[1:]) + self.context.set_input_shape(k, shape) + + @torch.cuda.nvtx.range("setup_tensors") + def setup_tensors(self, model_inputs): + raise NotImplementedError + + @classmethod + def load_trt_model(cls, engine_path, model_type): + try: + engine = torch.load(engine_path) + config = engine["config"] + model, helper = get_model_from_version(model_type, config) + unet = cls(helper) + unet.engine = runtime.deserialize_cuda_engine( + engine["engine"].numpy().tobytes() + ) + except: + model, helper = get_model_from_version(model_type, {}) + unet = cls(helper) + with open(engine_path, "rb") as f: + unet.engine = runtime.deserialize_cuda_engine(f.read()) + + if unet.engine is None: + raise Exception("Failed to load Engine") + + unet.context = unet.engine.create_execution_context() + model.diffusion_model = unet + model.memory_required = ( + lambda *args, **kwargs: 0 + ) # always pass inputs batched up as much as possible + return model + + @torch.cuda.nvtx.range("__call__") + def __call__(self): + raise NotImplementedError + + +class TRTDiffusionBackbone(TRTModel): + + @torch.cuda.nvtx.range("setup_tensors") + def setup_tensors(self, model_inputs): + self.current_shape = model_inputs["x"].shape + self.extra_inputs = {} + + dims = self.engine.get_tensor_profile_shape(self.engine.get_tensor_name(0), 0) 
+ batch_size, _, height, width = self.current_shape + height *= 8 + width *= 8 + _, context_len, _ = model_inputs["context"].shape + min_batch = dims[0][0] + max_batch = dims[2][0] + # Split batch if our batch is bigger than the max batch size the trt engine supports + for i in range(max_batch, min_batch - 1, -1): + if batch_size % i == 0: + self.curr_split_batch = batch_size // i + break + + self.set_bindings_shape(model_inputs) + + # Inputs missing, use zero + max_memory = 0 + for name in self.input_names: + if name in model_inputs: + continue + shape = self.model.get_input_shapes_by_key( + name, + batch_size=batch_size, + height=height, + width=width, + context_len=context_len, + ) + shape = (shape[0] // self.curr_split_batch, *shape[1:]) + self.context.set_input_shape(name, shape) + self.extra_inputs[name] = 0 + max_memory = max(prod(shape), max_memory) + self.zero_pool = torch.zeros( + max_memory, device=self.device, dtype=self.dtype + ).contiguous() + + self.output_shapes = {} + for name in self.output_names: + shape = list(self.engine.get_tensor_shape("h")) + for idx in range(len(shape)): + if shape[idx] == -1: + shape[idx] = model_inputs["x"].shape[idx] + if idx == 0: + shape[idx] = batch_size + self.output_shapes[name] = shape + + @torch.cuda.nvtx.range("__call__") + def __call__( + self, + x, + timesteps, + context, + y=None, + control=None, + transformer_options=None, + **kwargs, + ): + model_inputs = {"x": x, "timesteps": timesteps, "context": context} + + if y is not None: + model_inputs["y"] = y + + if self.model.use_control and control is not None: + for control_layer, control_tensors in control.items(): + for i, tensor in enumerate(control_tensors): + model_inputs[f"{control_layer}_control_{i}"] = tensor + + # TODO actually needed? + # for k, v in kwargs.items(): + # model_inputs[k] = v + + if self.current_shape != x.shape: + self.setup_tensors(model_inputs) + + model_inputs["h"] = torch.empty( + self.output_shapes["h"], + device=self.device, + dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype("h")), + ).contiguous() + + for k in model_inputs: + trt_dtype = trt_datatype_to_torch(self.engine.get_tensor_dtype(k)) + if model_inputs[k].dtype != trt_dtype: + model_inputs[k] = model_inputs[k].to(trt_dtype) + + torch.cuda.nvtx.range_push("infer") + stream = torch.cuda.default_stream(x.device) + for i in range(self.curr_split_batch): + for k, v in model_inputs.items(): + self.context.set_tensor_address( + k, v[(v.shape[0] // self.curr_split_batch) * i:].data_ptr() + ) + for k in self.extra_inputs.keys(): + self.context.set_tensor_address(k, self.zero_pool.data_ptr()) + self.context.execute_async_v3(stream_handle=stream.cuda_stream) + torch.cuda.nvtx.range_pop() + + return model_inputs["h"] + + +class TQDMProgressMonitor(trt.IProgressMonitor): + def __init__(self): + trt.IProgressMonitor.__init__(self) + self._active_phases = {} + self._step_result = True + self.max_indent = 5 + + def phase_start(self, phase_name, parent_phase, num_steps): + leave = False + try: + if parent_phase is not None: + nb_indents = ( + self._active_phases.get(parent_phase, {}).get( + "nbIndents", self.max_indent + ) + + 1 + ) + if nb_indents >= self.max_indent: + return + else: + nb_indents = 0 + leave = True + self._active_phases[phase_name] = { + "tq": tqdm( + total=num_steps, desc=phase_name, leave=leave, position=nb_indents + ), + "nbIndents": nb_indents, + "parent_phase": parent_phase, + } + except KeyboardInterrupt: + # The phase_start callback cannot directly cancel the build, so request the 
cancellation from within step_complete.
+            self._step_result = False
+
+    def phase_finish(self, phase_name):
+        try:
+            if phase_name in self._active_phases.keys():
+                self._active_phases[phase_name]["tq"].update(
+                    self._active_phases[phase_name]["tq"].total
+                    - self._active_phases[phase_name]["tq"].n
+                )
+
+                parent_phase = self._active_phases[phase_name].get("parent_phase", None)
+                while parent_phase is not None:
+                    self._active_phases[parent_phase]["tq"].refresh()
+                    parent_phase = self._active_phases[parent_phase].get(
+                        "parent_phase", None
+                    )
+                if (
+                    self._active_phases[phase_name]["parent_phase"]
+                    in self._active_phases.keys()
+                ):
+                    self._active_phases[
+                        self._active_phases[phase_name]["parent_phase"]
+                    ]["tq"].refresh()
+                del self._active_phases[phase_name]
+            pass
+        except KeyboardInterrupt:
+            self._step_result = False
+
+    def step_complete(self, phase_name, step):
+        try:
+            if phase_name in self._active_phases.keys():
+                self._active_phases[phase_name]["tq"].update(
+                    step - self._active_phases[phase_name]["tq"].n
+                )
+            return self._step_result
+        except KeyboardInterrupt:
+            # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
+            return False
diff --git a/tensorrt_loader.py b/tensorrt_loader.py
deleted file mode 100644
index 5e2ccac..0000000
--- a/tensorrt_loader.py
+++ /dev/null
@@ -1,176 +0,0 @@
-#Put this in the custom_nodes folder, put your tensorrt engine files in ComfyUI/models/tensorrt/ (you will have to create the directory)
-
-import torch
-import os
-
-import comfy.model_base
-import comfy.model_management
-import comfy.model_patcher
-import comfy.supported_models
-import folder_paths
-
-if "tensorrt" in folder_paths.folder_names_and_paths:
-    folder_paths.folder_names_and_paths["tensorrt"][0].append(
-        os.path.join(folder_paths.models_dir, "tensorrt"))
-    folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine")
-else:
-    folder_paths.folder_names_and_paths["tensorrt"] = (
-        [os.path.join(folder_paths.models_dir, "tensorrt")], {".engine"})
-
-import tensorrt as trt
-
-trt.init_libnvinfer_plugins(None, "")
-
-logger = trt.Logger(trt.Logger.INFO)
-runtime = trt.Runtime(logger)
-
-# Is there a function that already exists for this?
-def trt_datatype_to_torch(datatype): - if datatype == trt.float16: - return torch.float16 - elif datatype == trt.float32: - return torch.float32 - elif datatype == trt.int32: - return torch.int32 - elif datatype == trt.bfloat16: - return torch.bfloat16 - -class TrTUnet: - def __init__(self, engine_path): - with open(engine_path, "rb") as f: - self.engine = runtime.deserialize_cuda_engine(f.read()) - self.context = self.engine.create_execution_context() - self.dtype = torch.float16 - - def set_bindings_shape(self, inputs, split_batch): - for k in inputs: - shape = inputs[k].shape - shape = [shape[0] // split_batch] + list(shape[1:]) - self.context.set_input_shape(k, shape) - - def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs): - model_inputs = {"x": x, "timesteps": timesteps, "context": context} - - if y is not None: - model_inputs["y"] = y - - for i in range(len(model_inputs), self.engine.num_io_tensors - 1): - name = self.engine.get_tensor_name(i) - model_inputs[name] = kwargs[name] - - batch_size = x.shape[0] - dims = self.engine.get_tensor_profile_shape(self.engine.get_tensor_name(0), 0) - min_batch = dims[0][0] - opt_batch = dims[1][0] - max_batch = dims[2][0] - - #Split batch if our batch is bigger than the max batch size the trt engine supports - for i in range(max_batch, min_batch - 1, -1): - if batch_size % i == 0: - curr_split_batch = batch_size // i - break - - self.set_bindings_shape(model_inputs, curr_split_batch) - - model_inputs_converted = {} - for k in model_inputs: - data_type = self.engine.get_tensor_dtype(k) - model_inputs_converted[k] = model_inputs[k].to(dtype=trt_datatype_to_torch(data_type)) - - output_binding_name = self.engine.get_tensor_name(len(model_inputs)) - out_shape = self.engine.get_tensor_shape(output_binding_name) - out_shape = list(out_shape) - - #for dynamic profile case where the dynamic params are -1 - for idx in range(len(out_shape)): - if out_shape[idx] == -1: - out_shape[idx] = x.shape[idx] - else: - if idx == 0: - out_shape[idx] *= curr_split_batch - - out = torch.empty(out_shape, - device=x.device, - dtype=trt_datatype_to_torch(self.engine.get_tensor_dtype(output_binding_name))) - model_inputs_converted[output_binding_name] = out - - stream = torch.cuda.default_stream(x.device) - for i in range(curr_split_batch): - for k in model_inputs_converted: - x = model_inputs_converted[k] - self.context.set_tensor_address(k, x[(x.shape[0] // curr_split_batch) * i:].data_ptr()) - self.context.execute_async_v3(stream_handle=stream.cuda_stream) - # stream.synchronize() #don't need to sync stream since it's the default torch one - return out - - def load_state_dict(self, sd, strict=False): - pass - - def state_dict(self): - return {} - - -class TensorRTLoader: - @classmethod - def INPUT_TYPES(s): - return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ), - "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow", "flux_dev", "flux_schnell"], ), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "load_unet" - CATEGORY = "TensorRT" - - def load_unet(self, unet_name, model_type): - unet_path = folder_paths.get_full_path("tensorrt", unet_name) - if not os.path.isfile(unet_path): - raise FileNotFoundError(f"File {unet_path} does not exist") - unet = TrTUnet(unet_path) - if model_type == "sdxl_base": - conf = comfy.supported_models.SDXL({"adm_in_channels": 2816}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.SDXL(conf) - 
elif model_type == "sdxl_refiner": - conf = comfy.supported_models.SDXLRefiner( - {"adm_in_channels": 2560}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.SDXLRefiner(conf) - elif model_type == "sd1.x": - conf = comfy.supported_models.SD15({}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.BaseModel(conf) - elif model_type == "sd2.x-768v": - conf = comfy.supported_models.SD20({}) - conf.unet_config["disable_unet_model_creation"] = True - model = comfy.model_base.BaseModel(conf, model_type=comfy.model_base.ModelType.V_PREDICTION) - elif model_type == "svd": - conf = comfy.supported_models.SVD_img2vid({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - elif model_type == "sd3": - conf = comfy.supported_models.SD3({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - elif model_type == "auraflow": - conf = comfy.supported_models.AuraFlow({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - elif model_type == "flux_dev": - conf = comfy.supported_models.Flux({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - unet.dtype = torch.bfloat16 #TODO: autodetect - elif model_type == "flux_schnell": - conf = comfy.supported_models.FluxSchnell({}) - conf.unet_config["disable_unet_model_creation"] = True - model = conf.get_model({}) - unet.dtype = torch.bfloat16 #TODO: autodetect - model.diffusion_model = unet - model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting - - return (comfy.model_patcher.ModelPatcher(model, - load_device=comfy.model_management.get_torch_device(), - offload_device=comfy.model_management.unet_offload_device()),) - -NODE_CLASS_MAPPINGS = { - "TensorRTLoader": TensorRTLoader, -} diff --git a/tensorrt_nodes.py b/tensorrt_nodes.py new file mode 100644 index 0000000..b41be31 --- /dev/null +++ b/tensorrt_nodes.py @@ -0,0 +1,466 @@ +# Put this in the custom_nodes folder, put your tensorrt engine files in ComfyUI/models/tensorrt/ (you will have to create the directory) +import os +import time +from typing import Optional + +import comfy.model_management +import comfy.model_patcher +import folder_paths + +from .models import supported_models, detect_version_from_model, get_helper_from_model +from .onnx_utils.export import export_onnx +from .tensorrt_diffusion_model import TRTDiffusionBackbone + +# add output directory to tensorrt search path +if "tensorrt" in folder_paths.folder_names_and_paths: + folder_paths.folder_names_and_paths["tensorrt"][0].append( + os.path.join(folder_paths.get_output_directory(), "tensorrt") + ) + folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine") +else: + folder_paths.folder_names_and_paths["tensorrt"] = ( + [os.path.join(folder_paths.get_output_directory(), "tensorrt")], + {".engine"}, + ) + + +class TensorRTLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "unet_name": (folder_paths.get_filename_list("tensorrt"),), + "model_type": (list(supported_models.keys()),), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_unet" + CATEGORY = "TensorRT" + + @staticmethod + def load_unet(unet_name, model_type): + unet_path = folder_paths.get_full_path("tensorrt", unet_name) + model = TRTDiffusionBackbone.load_trt_model(unet_path, model_type) + return ( + comfy.model_patcher.ModelPatcher( + model, + 
load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device(), + ), + ) + + +class TRTBuildBase: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.temp_dir = folder_paths.get_temp_directory() + self.timing_cache_path = os.path.normpath( + os.path.join( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt" + ) + ) + ) + + RETURN_TYPES = () + FUNCTION = "convert" + OUTPUT_NODE = True + CATEGORY = "TensorRT" + + @classmethod + def INPUT_TYPES(s): + raise NotImplementedError + + def _convert( + self, + model, + filename_prefix, + batch_size_min, + batch_size_opt, + batch_size_max, + height_min, + height_opt, + height_max, + width_min, + width_opt, + width_max, + context_min, + context_opt, + context_max, + num_video_frames, + is_static: bool, + output_onnx: Optional[str] = None, + ): + if output_onnx is None: + output_onnx = os.path.normpath( + os.path.join( + os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx" + ) + ) + os.makedirs(os.path.dirname(output_onnx), exist_ok=True) + + comfy.model_management.unload_all_models() + comfy.model_management.load_models_gpu( + [model], force_patch_weights=True, force_full_load=True + ) + export_onnx(model, output_onnx) + + model_version = detect_version_from_model(model) + model_helper = get_helper_from_model(model) + trt_model = TRTDiffusionBackbone(model_helper) + + filename_prefix = f"{filename_prefix}_{model_version}" + if is_static: + filename_prefix = "{}_${}".format( + filename_prefix, + "-".join( + ( + "stat", + "b", + str(batch_size_opt), + "h", + str(height_opt), + "w", + str(width_opt), + ) + ), + ) + else: + filename_prefix = "{}_${}".format( + filename_prefix, + "-".join( + ( + "dyn", + "b", + str(batch_size_min), + str(batch_size_max), + str(batch_size_opt), + "h", + str(height_min), + str(height_max), + str(height_opt), + "w", + str(width_min), + str(width_max), + str(width_opt), + ) + ), + ) + + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(filename_prefix, self.output_dir) + ) + output_trt_engine = os.path.join( + full_output_folder, f"{filename}_{counter:05}_.engine" + ) + + batch_multiplier = ( + 2 if model_helper.is_conditional else 1 + ) # TODO lets see if we really want this + if model_version == "SVD_img2vid": + batch_multiplier *= num_video_frames + success = trt_model.build( + output_onnx, + output_trt_engine, + self.timing_cache_path, + opt_config={ + "batch_size": batch_size_opt * batch_multiplier, + "height": height_opt, + "width": width_opt, + "context_len": context_opt * model_helper.context_len, + }, + min_config={ + "batch_size": batch_size_min * batch_multiplier, + "height": height_min, + "width": width_min, + "context_len": context_min * model_helper.context_len, + }, + max_config={ + "batch_size": batch_size_max * batch_multiplier, + "height": height_max, + "width": width_max, + "context_len": context_max * model_helper.context_len, + }, + ) + if not success: + raise RuntimeError("Engine Build Failed") + return () + + +class DynamicTRTBuild(TRTBuildBase): + def __init__(self): + super(DynamicTRTBuild, self).__init__() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}), + "batch_size_min": ( + "INT", + { + "default": 1, + "min": 1, + "max": 100, + "step": 1, + }, + ), + "batch_size_opt": ( + "INT", + { + "default": 1, + "min": 
1, + "max": 100, + "step": 1, + }, + ), + "batch_size_max": ( + "INT", + { + "default": 1, + "min": 1, + "max": 100, + "step": 1, + }, + ), + "height_min": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "height_opt": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "height_max": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "width_min": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "width_opt": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "width_max": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "context_min": ( + "INT", + { + "default": 1, + "min": 1, + "max": 128, + "step": 1, + }, + ), + "context_opt": ( + "INT", + { + "default": 1, + "min": 1, + "max": 128, + "step": 1, + }, + ), + "context_max": ( + "INT", + { + "default": 1, + "min": 1, + "max": 128, + "step": 1, + }, + ), + "num_video_frames": ( + "INT", + { + "default": 14, + "min": 0, + "max": 1000, + "step": 1, + }, + ), + }, + "optional": { + "onnx_model_path": ("STRING", {"default": "", "forceInput": True}), + }, + } + + def convert( + self, + model, + filename_prefix, + batch_size_min, + batch_size_opt, + batch_size_max, + height_min, + height_opt, + height_max, + width_min, + width_opt, + width_max, + context_min, + context_opt, + context_max, + num_video_frames, + onnx_model_path, + ): + return super()._convert( + model, + filename_prefix, + batch_size_min, + batch_size_opt, + batch_size_max, + height_min, + height_opt, + height_max, + width_min, + width_opt, + width_max, + context_min, + context_opt, + context_max, + num_video_frames, + is_static=False, + output_onnx=onnx_model_path, + ) + + +class StaticTRTBuild(TRTBuildBase): + def __init__(self): + super(StaticTRTBuild, self).__init__() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}), + "batch_size_opt": ( + "INT", + { + "default": 1, + "min": 1, + "max": 100, + "step": 1, + }, + ), + "height_opt": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "width_opt": ( + "INT", + { + "default": 512, + "min": 256, + "max": 4096, + "step": 64, + }, + ), + "context_opt": ( + "INT", + { + "default": 1, + "min": 1, + "max": 128, + "step": 1, + }, + ), + "num_video_frames": ( + "INT", + { + "default": 14, + "min": 0, + "max": 1000, + "step": 1, + }, + ), + }, + "optional": { + "onnx_model_path": ("STRING", {"default": "", "forceInput": True}), + }, + } + + def convert( + self, + model, + filename_prefix, + batch_size_opt, + height_opt, + width_opt, + context_opt, + num_video_frames, + onnx_model_path, + ): + return super()._convert( + model, + filename_prefix, + batch_size_opt, + batch_size_opt, + batch_size_opt, + height_opt, + height_opt, + height_opt, + width_opt, + width_opt, + width_opt, + context_opt, + context_opt, + context_opt, + num_video_frames, + is_static=True, + output_onnx=onnx_model_path, + ) + + +NODE_CLASS_MAPPINGS = { + "TensorRTLoader": TensorRTLoader, + "DYNAMIC_TRT_MODEL_CONVERSION": DynamicTRTBuild, + "STATIC_TRT_MODEL_CONVERSION": StaticTRTBuild, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "TensorRTLoader": "TensorRT Loader", + "DYNAMIC_TRT_MODEL_CONVERSION": "DYNAMIC TRT_MODEL CONVERSION", + "STATIC TRT_MODEL CONVERSION": "STATIC_TRT_MODEL_CONVERSION", +}
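# Note on batch splitting (illustrative sketch, not part of the patch): the
# loader sets model.memory_required to 0 so ComfyUI hands the TensorRT
# backbone the largest batch it can, and TRTDiffusionBackbone.setup_tensors
# then searches for the largest per-launch batch the engine profile accepts.
# The helper name below is hypothetical; the loop mirrors the one in
# setup_tensors.

def pick_split_batch(batch_size: int, min_batch: int, max_batch: int) -> int:
    """Return how many sequential engine launches are needed so that each
    launch uses a batch the profile accepts and the full batch divides evenly."""
    for per_launch in range(max_batch, min_batch - 1, -1):
        if batch_size % per_launch == 0:
            return batch_size // per_launch
    raise ValueError("batch size cannot be split to fit the engine profile")

# Example: an engine built for batches 1..2 asked to run a batch of 6
# -> pick_split_batch(6, 1, 2) == 3, i.e. three launches of 2.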
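# Note on the new engine file layout (rough sketch under the patch's own
# conventions): build() saves a torch dict containing the serialized plan and
# the model config, while load_trt_model falls back to reading a legacy raw
# .engine file when torch.load fails. read_engine below is a hypothetical
# standalone helper showing that round trip.

import torch
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(logger)

def read_engine(path: str):
    """Deserialize an engine saved by the new build() or an old raw file."""
    try:
        container = torch.load(path)  # new layout: {"engine": uint8 tensor, "config": dict}
        plan = container["engine"].numpy().tobytes()
        config = container.get("config", {})
    except Exception:
        with open(path, "rb") as f:   # legacy layout: raw serialized plan bytes
            plan = f.read()
        config = {}
    return runtime.deserialize_cuda_engine(plan), config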
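# Note on static vs. dynamic builds (sketch with illustrative numbers): the
# static node passes its single opt values for min/opt/max, and
# TRTModel.build() defaults a missing min_config/max_config to opt_config in
# _create_profile, so a static engine is just a dynamic engine whose
# optimization profile has collapsed to one shape. The context_len values
# below are assumptions for illustration only.

dynamic_configs = {
    "opt_config": {"batch_size": 2, "height": 768, "width": 768, "context_len": 77},
    "min_config": {"batch_size": 1, "height": 512, "width": 512, "context_len": 77},
    "max_config": {"batch_size": 4, "height": 1024, "width": 1024, "context_len": 154},
}

static_configs = {
    "opt_config": {"batch_size": 2, "height": 768, "width": 768, "context_len": 77},
    "min_config": None,  # falls back to opt_config inside _create_profile
    "max_config": None,
}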