Skip to content

Commit a752a2f

Browse files
AuraFlow support.
1 parent a9a6923 commit a752a2f

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

tensorrt_convert.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
 from tqdm import tqdm

 # TODO:
-# Deal with xformers if it's enabled
 # Make it more generic: less model specific code

 # add output directory to tensorrt search path
@@ -163,11 +162,15 @@ def _convert(
         context_len = 77
         context_len_min = context_len

-        if context_dim is None: #SD3
+        if isinstance(model.model, comfy.model_base.SD3): #SD3
             context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
             if context_embedder_config is not None:
                 context_dim = context_embedder_config.get("params", {}).get("in_features", None)
                 context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
+        elif isinstance(model.model, comfy.model_base.AuraFlow):
+            context_dim = 2048
+            context_len_min = 256
+            context_len = 256

         if context_dim is not None:
             input_names = ["x", "timesteps", "context"]
@@ -207,19 +210,22 @@ def forward(self, x, timesteps, context, y):
         else:
             class UNET(torch.nn.Module):
                 def forward(self, x, timesteps, context, y=None):
-                    return self.unet(
-                        x,
-                        timesteps,
-                        context,
-                        y,
-                        transformer_options=self.transformer_options,
-                    )
+                    if y is None:
+                        return self.unet(x, timesteps, context, transformer_options=self.transformer_options)
+                    else:
+                        return self.unet(
+                            x,
+                            timesteps,
+                            context,
+                            y,
+                            transformer_options=self.transformer_options,
+                        )
             _unet = UNET()
             _unet.unet = unet
             _unet.transformer_options = transformer_options
             unet = _unet

-    input_channels = model.model.model_config.unet_config.get("in_channels")
+    input_channels = model.model.model_config.unet_config.get("in_channels", 4)

     inputs_shapes_min = (
         (batch_size_min, input_channels, height_min // 8, width_min // 8),

tensorrt_loader.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class TensorRTLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": {"unet_name": (folder_paths.get_filename_list("tensorrt"), ),
-                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3"], ),
+                             "model_type": (["sdxl_base", "sdxl_refiner", "sd1.x", "sd2.x-768v", "svd", "sd3", "auraflow"], ),
                             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -146,6 +146,10 @@ def load_unet(self, unet_name, model_type):
             conf = comfy.supported_models.SD3({})
             conf.unet_config["disable_unet_model_creation"] = True
             model = conf.get_model({})
+        elif model_type == "auraflow":
+            conf = comfy.supported_models.AuraFlow({})
+            conf.unet_config["disable_unet_model_creation"] = True
+            model = conf.get_model({})
         model.diffusion_model = unet
         model.memory_required = lambda *args, **kwargs: 0 #always pass inputs batched up as much as possible, our TRT code will handle batch splitting
151155

0 commit comments

Comments (0)