Skip to content

Commit ed282ab

Browse files
committed
fp16 -> bf16, add reuse_model option, code refactor
1 parent 6bf3b54 commit ed282ab

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

tensorrt_convert.py

Lines changed: 40 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -24,6 +24,7 @@
2424
{".engine"},
2525
)
2626

27+
2728
class TQDMProgressMonitor(trt.IProgressMonitor):
2829
def __init__(self):
2930
trt.IProgressMonitor.__init__(self)
@@ -93,14 +94,18 @@ def step_complete(self, phase_name, step):
9394
except KeyboardInterrupt:
9495
# There is no need to propagate this exception to TensorRT. We can simply cancel the build.
9596
return False
96-
97+
9798

9899
class TRT_MODEL_CONVERSION_BASE:
99100
def __init__(self):
100101
self.output_dir = folder_paths.get_output_directory()
101102
self.temp_dir = folder_paths.get_temp_directory()
102103
self.timing_cache_path = os.path.normpath(
103-
os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
104+
os.path.join(
105+
os.path.join(
106+
os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
107+
)
108+
)
104109
)
105110

106111
RETURN_TYPES = ()
@@ -148,26 +153,30 @@ def _convert(
148153
context_max,
149154
num_video_frames,
150155
is_static: bool,
156+
reuse_model: bool = False,
151157
):
152158
output_onnx = os.path.normpath(
153-
os.path.join(
154-
os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
155-
)
159+
os.path.join(self.temp_dir, str(time.time()), "model.onnx")
156160
)
157161

158-
comfy.model_management.unload_all_models()
159-
comfy.model_management.load_models_gpu([model], force_patch_weights=True)
162+
if not reuse_model:
163+
comfy.model_management.unload_all_models()
164+
comfy.model_management.load_models_gpu([model], force_patch_weights=True)
160165
unet = model.model.diffusion_model
161166

162167
context_dim = model.model.model_config.unet_config.get("context_dim", None)
163168
context_len = 77
164169
context_len_min = context_len
165170

166-
if context_dim is None: #SD3
167-
context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
171+
if context_dim is None: # SD3
172+
context_embedder_config = model.model.model_config.unet_config.get(
173+
"context_embedder_config", None
174+
)
168175
if context_embedder_config is not None:
169-
context_dim = context_embedder_config.get("params", {}).get("in_features", None)
170-
context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
176+
context_dim = context_embedder_config.get("params", {}).get(
177+
"in_features", None
178+
)
179+
context_len = 154 # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
171180

172181
if context_dim is not None:
173182
input_names = ["x", "timesteps", "context"]
@@ -179,7 +188,7 @@ def _convert(
179188
"context": {0: "batch", 1: "num_embeds"},
180189
}
181190

182-
transformer_options = model.model_options['transformer_options'].copy()
191+
transformer_options = model.model_options["transformer_options"].copy()
183192
if model.model.model_config.unet_config.get(
184193
"use_temporal_resblock", False
185194
): # SVD
@@ -205,7 +214,13 @@ def forward(self, x, timesteps, context, y):
205214
unet = svd_unet
206215
context_len_min = context_len = 1
207216
else:
217+
208218
class UNET(torch.nn.Module):
219+
def __init__(self, unet, opts):
220+
super().__init__()
221+
self.unet = unet
222+
self.transformer_options = opts
223+
209224
def forward(self, x, timesteps, context, y=None):
210225
return self.unet(
211226
x,
@@ -214,10 +229,8 @@ def forward(self, x, timesteps, context, y=None):
214229
y,
215230
transformer_options=self.transformer_options,
216231
)
217-
_unet = UNET()
218-
_unet.unet = unet
219-
_unet.transformer_options = transformer_options
220-
unet = _unet
232+
233+
unet = UNET(unet, transformer_options)
221234

222235
input_channels = model.model.model_config.unet_config.get("in_channels")
223236

@@ -252,7 +265,7 @@ def forward(self, x, timesteps, context, y=None):
252265
torch.zeros(
253266
shape,
254267
device=comfy.model_management.get_torch_device(),
255-
dtype=torch.float16,
268+
dtype=torch.bfloat16,
256269
),
257270
)
258271

@@ -272,8 +285,9 @@ def forward(self, x, timesteps, context, y=None):
272285
dynamic_axes=dynamic_axes,
273286
)
274287

275-
comfy.model_management.unload_all_models()
276-
comfy.model_management.soft_empty_cache()
288+
if not reuse_model:
289+
comfy.model_management.unload_all_models()
290+
comfy.model_management.soft_empty_cache()
277291

278292
# TRT conversion starts here
279293
logger = trt.Logger(trt.Logger.INFO)
@@ -304,12 +318,14 @@ def forward(self, x, timesteps, context, y=None):
304318
profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)
305319

306320
# Encode shapes to filename
307-
encode = lambda a: ".".join(map(lambda x: str(x), a))
321+
def encode(a):
322+
return ".".join(map(str, a))
323+
308324
prefix_encode += "{}#{}#{}#{};".format(
309325
input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
310326
)
311327

312-
config.set_flag(trt.BuilderFlag.FP16)
328+
config.set_flag(trt.BuilderFlag.BF16)
313329
config.add_optimization_profile(profile)
314330

315331
if is_static:
@@ -589,6 +605,7 @@ def INPUT_TYPES(s):
589605
"step": 1,
590606
},
591607
),
608+
"reuse_model": ("BOOLEAN", {"default": False}),
592609
},
593610
}
594611

@@ -601,6 +618,7 @@ def convert(
601618
width_opt,
602619
context_opt,
603620
num_video_frames,
621+
reuse_model,
604622
):
605623
return super()._convert(
606624
model,
@@ -619,6 +637,7 @@ def convert(
619637
context_opt,
620638
num_video_frames,
621639
is_static=True,
640+
reuse_model=reuse_model,
622641
)
623642

624643

0 commit comments

Comments (0)