diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..82624a5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +*.trt \ No newline at end of file diff --git a/models/sd_unet.py b/models/sd_unet.py index d7d42a2..d612330 100644 --- a/models/sd_unet.py +++ b/models/sd_unet.py @@ -192,6 +192,10 @@ def __init__( **kwargs, ) + @classmethod + def from_model(cls, model, **kwargs): + return super(SD21UnclipL_TRT, cls).from_model(model, use_control=True) + class SD21UnclipH_TRT(UNetTRT): def __init__( @@ -214,6 +218,10 @@ def __init__( **kwargs, ) + @classmethod + def from_model(cls, model, **kwargs): + return super(SD21UnclipH_TRT, cls).from_model(model, use_control=True) + class SDXLRefiner_TRT(UNetTRT): def __init__( diff --git a/tensorrt_nodes.py b/tensorrt_nodes.py index b41be31..5d7c89a 100644 --- a/tensorrt_nodes.py +++ b/tensorrt_nodes.py @@ -155,9 +155,7 @@ def _convert( full_output_folder, f"{filename}_{counter:05}_.engine" ) - batch_multiplier = ( - 2 if model_helper.is_conditional else 1 - ) # TODO lets see if we really want this + batch_multiplier = 1 if model_version == "SVD_img2vid": batch_multiplier *= num_video_frames success = trt_model.build( @@ -338,7 +336,7 @@ def convert( context_opt, context_max, num_video_frames, - onnx_model_path, + onnx_model_path = None, ): return super()._convert( model, @@ -431,7 +429,7 @@ def convert( width_opt, context_opt, num_video_frames, - onnx_model_path, + onnx_model_path = None, ): return super()._convert( model,