diff --git a/models.py b/models.py index 1355c2b..123f758 100644 --- a/models.py +++ b/models.py @@ -317,7 +317,7 @@ def get_batch_dim(self, min_batch, opt_batch, max_batch, static_batch): return (opt_batch, opt_batch, opt_batch) elif self.text_maxlen > 77 and not static_batch: if self.text_optlen > 77: - return (min_batch, opt_batch, max_batch * 2) + return (min_batch, opt_batch * 2, max_batch * 2) return (min_batch, opt_batch * 2, max_batch * 2) else: raise Exception("Uncovered case in get_batch_dim") @@ -810,6 +810,7 @@ def __init__( verbose=True, fp16=False, max_batch_size=16, + text_minlen=77, text_maxlen=77, text_optlen=77, unet_dim=4, @@ -828,6 +829,7 @@ def __init__( ) self.unet_dim = unet_dim self.controlnet = controlnet + self.text_minlen = text_minlen self.text_optlen = text_optlen def get_input_names(self): @@ -896,7 +898,7 @@ def get_input_profile( ], "timesteps": [(min_batch,), (opt_batch,), (max_batch,)], "encoder_hidden_states": [ - (min_batch, self.text_optlen, self.embedding_dim), + (min_batch, self.text_minlen, self.embedding_dim), (opt_batch, self.text_optlen, self.embedding_dim), (max_batch, self.text_maxlen, self.embedding_dim), ], @@ -1034,6 +1036,7 @@ def make_OAIUNet( device, verbose, max_batch_size, + text_minlen, text_optlen, text_maxlen, controlnet=None, @@ -1045,6 +1048,7 @@ def make_OAIUNet( device=device, verbose=verbose, max_batch_size=max_batch_size, + text_minlen=text_minlen, text_optlen=text_optlen, text_maxlen=text_maxlen, unet_dim=(9 if pipeline.is_inpaint() else 4), diff --git a/ui_trt.py b/ui_trt.py index 4ae84cf..37700b9 100644 --- a/ui_trt.py +++ b/ui_trt.py @@ -109,6 +109,7 @@ def export_unet_to_trt( "cuda", False, batch_max, + min_textlen, opt_textlen, max_textlen, controlnet,