2424 {".engine" },
2525 )
2626
27+
2728class TQDMProgressMonitor (trt .IProgressMonitor ):
2829 def __init__ (self ):
2930 trt .IProgressMonitor .__init__ (self )
@@ -93,14 +94,18 @@ def step_complete(self, phase_name, step):
         except KeyboardInterrupt:
             # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
             return False
-
+
 
 class TRT_MODEL_CONVERSION_BASE:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()
         self.temp_dir = folder_paths.get_temp_directory()
         self.timing_cache_path = os.path.normpath(
-            os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
+            os.path.join(
+                os.path.join(
+                    os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
+                )
+            )
         )
 
     RETURN_TYPES = ()
@@ -148,26 +153,30 @@ def _convert(
         context_max,
         num_video_frames,
         is_static: bool,
+        reuse_model: bool = False,
     ):
         output_onnx = os.path.normpath(
-            os.path.join(
-                os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
-            )
+            os.path.join(self.temp_dir, str(time.time()), "model.onnx")
         )
 
-        comfy.model_management.unload_all_models()
-        comfy.model_management.load_models_gpu([model], force_patch_weights=True)
+        if not reuse_model:
+            comfy.model_management.unload_all_models()
+            comfy.model_management.load_models_gpu([model], force_patch_weights=True)
         unet = model.model.diffusion_model
 
         context_dim = model.model.model_config.unet_config.get("context_dim", None)
         context_len = 77
         context_len_min = context_len
 
-        if context_dim is None: #SD3
-            context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
+        if context_dim is None:  # SD3
+            context_embedder_config = model.model.model_config.unet_config.get(
+                "context_embedder_config", None
+            )
             if context_embedder_config is not None:
-                context_dim = context_embedder_config.get("params", {}).get("in_features", None)
-                context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
+                context_dim = context_embedder_config.get("params", {}).get(
+                    "in_features", None
+                )
+                context_len = 154  # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
 
         if context_dim is not None:
             input_names = ["x", "timesteps", "context"]
@@ -179,7 +188,7 @@ def _convert(
179188 "context" : {0 : "batch" , 1 : "num_embeds" },
180189 }
181190
182- transformer_options = model .model_options [' transformer_options' ].copy ()
191+ transformer_options = model .model_options [" transformer_options" ].copy ()
183192 if model .model .model_config .unet_config .get (
184193 "use_temporal_resblock" , False
185194 ): # SVD
@@ -205,7 +214,13 @@ def forward(self, x, timesteps, context, y):
             unet = svd_unet
             context_len_min = context_len = 1
         else:
+
             class UNET(torch.nn.Module):
+                def __init__(self, unet, opts):
+                    super().__init__()
+                    self.unet = unet
+                    self.transformer_options = opts
+
                 def forward(self, x, timesteps, context, y=None):
                     return self.unet(
                         x,
@@ -214,10 +229,8 @@ def forward(self, x, timesteps, context, y=None):
                         y,
                         transformer_options=self.transformer_options,
                     )
-            _unet = UNET()
-            _unet.unet = unet
-            _unet.transformer_options = transformer_options
-            unet = _unet
+
+            unet = UNET(unet, transformer_options)
 
         input_channels = model.model.model_config.unet_config.get("in_channels")
 
@@ -252,7 +265,7 @@ def forward(self, x, timesteps, context, y=None):
                 torch.zeros(
                     shape,
                     device=comfy.model_management.get_torch_device(),
-                    dtype=torch.float16,
+                    dtype=torch.bfloat16,
                 ),
 
 
@@ -272,8 +285,9 @@ def forward(self, x, timesteps, context, y=None):
             dynamic_axes=dynamic_axes,
         )
 
-        comfy.model_management.unload_all_models()
-        comfy.model_management.soft_empty_cache()
+        if not reuse_model:
+            comfy.model_management.unload_all_models()
+            comfy.model_management.soft_empty_cache()
 
         # TRT conversion starts here
         logger = trt.Logger(trt.Logger.INFO)
@@ -304,12 +318,14 @@ def forward(self, x, timesteps, context, y=None):
             profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)
 
             # Encode shapes to filename
-            encode = lambda a: ".".join(map(lambda x: str(x), a))
+            def encode(a):
+                return ".".join(map(str, a))
+
             prefix_encode += "{}#{}#{}#{};".format(
                 input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
             )
 
-        config.set_flag(trt.BuilderFlag.FP16)
+        config.set_flag(trt.BuilderFlag.BF16)
         config.add_optimization_profile(profile)
 
         if is_static:
@@ -589,6 +605,7 @@ def INPUT_TYPES(s):
589605 "step" : 1 ,
590606 },
591607 ),
608+ "reuse_model" : ("BOOLEAN" , {"default" : False }),
592609 },
593610 }
594611
@@ -601,6 +618,7 @@ def convert(
         width_opt,
         context_opt,
         num_video_frames,
+        reuse_model,
     ):
         return super()._convert(
             model,
@@ -619,6 +637,7 @@ def convert(
             context_opt,
             num_video_frames,
             is_static=True,
+            reuse_model=reuse_model,
         )
 
 