9 | 9 | from tqdm import tqdm
10 | 10 |
11 | 11 | # TODO:
12 | | -# Deal with xformers if it's enabled
13 | 12 | # Make it more generic: less model specific code
14 | 13 |
15 | 14 | # add output directory to tensorrt search path
@@ -163,11 +162,15 @@ def _convert(
163 | 162 | context_len = 77
164 | 163 | context_len_min = context_len
165 | 164 |
166 | | - if context_dim is None: #SD3
| 165 | + if isinstance(model.model, comfy.model_base.SD3): #SD3
167 | 166 | context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
168 | 167 | if context_embedder_config is not None:
169 | 168 | context_dim = context_embedder_config.get("params", {}).get("in_features", None)
170 | 169 | context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
| 170 | + elif isinstance(model.model, comfy.model_base.AuraFlow):
| 171 | + context_dim = 2048
| 172 | + context_len_min = 256
| 173 | + context_len = 256
171 | 174 |
172 | 175 | if context_dim is not None:
173 | 176 | input_names = ["x", "timesteps", "context"]
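
For reference, the per-model branch above amounts to a small isinstance dispatch over the wrapped ComfyUI model classes. A minimal standalone sketch of the same logic; the function name get_context_params and passing the initial context_dim in as an argument are illustrative, not part of the patch:

import comfy.model_base

def get_context_params(model, context_dim):
    # Defaults for SD1.x/SD2.x/SDXL-style models: CLIP context of length 77.
    context_len_min = context_len = 77
    if isinstance(model.model, comfy.model_base.SD3):
        # SD3 keeps the context width in its context_embedder config; the sequence
        # length can be 77 or 154 depending on which text encoders are in use,
        # which is why context_len_min stays at 77.
        embedder = model.model.model_config.unet_config.get("context_embedder_config", None)
        if embedder is not None:
            context_dim = embedder.get("params", {}).get("in_features", None)
            context_len = 154
    elif isinstance(model.model, comfy.model_base.AuraFlow):
        # AuraFlow: fixed 2048-wide context, 256 tokens for both min and max.
        context_dim = 2048
        context_len_min = context_len = 256
    return context_dim, context_len_min, context_len
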
@@ -207,19 +210,22 @@ def forward(self, x, timesteps, context, y):
207 | 210 | else:
208 | 211 | class UNET(torch.nn.Module):
209 | 212 | def forward(self, x, timesteps, context, y=None):
210 | | - return self.unet(
211 | | - x,
212 | | - timesteps,
213 | | - context,
214 | | - y,
215 | | - transformer_options=self.transformer_options,
216 | | - )
| 213 | + if y is None:
| 214 | + return self.unet(x, timesteps, context, transformer_options=self.transformer_options)
| 215 | + else:
| 216 | + return self.unet(
| 217 | + x,
| 218 | + timesteps,
| 219 | + context,
| 220 | + y,
| 221 | + transformer_options=self.transformer_options,
| 222 | + )
217 | 223 | _unet = UNET()
218 | 224 | _unet.unet = unet
219 | 225 | _unet.transformer_options = transformer_options
220 | 226 | unet = _unet
221 | 227 |
222 | | - input_channels = model.model.model_config.unet_config.get("in_channels")
| 228 | + input_channels = model.model.model_config.unet_config.get("in_channels", 4)
223 | 229 |
224 | 230 | inputs_shapes_min = (
225 | 231 | (batch_size_min, input_channels, height_min // 8, width_min // 8),
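
The rewritten forward in the else branch only passes y through when it was actually given, presumably so that models whose forward() takes no y input (such as the flow-style architectures handled above) can be wrapped and exported with the same class. A self-contained sketch of the pattern; the UNET class mirrors the patch, while DummyUNet and the tensor shapes are illustrative stand-ins:

import torch

class UNET(torch.nn.Module):
    # Mirrors the wrapper in the patch: only forward `y` when it was actually given.
    def forward(self, x, timesteps, context, y=None):
        if y is None:
            return self.unet(x, timesteps, context,
                             transformer_options=self.transformer_options)
        return self.unet(x, timesteps, context, y,
                         transformer_options=self.transformer_options)

# Hypothetical stand-in for the real diffusion model; its forward has no `y` parameter.
class DummyUNet(torch.nn.Module):
    def forward(self, x, timesteps, context, transformer_options=None):
        return x

wrapper = UNET()
wrapper.unet = DummyUNet()          # attributes attached after construction, as in the patch
wrapper.transformer_options = {}

x = torch.randn(1, 4, 64, 64)
out = wrapper(x, torch.zeros(1), torch.randn(1, 256, 2048))  # y omitted, no TypeError
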
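The inputs_shapes_min tuple at the end of the last hunk (together with matching opt/max tuples built the same way) is the kind of data that feeds a TensorRT optimization profile for dynamic shapes; the // 8 reflects that the model runs on latents at one eighth of the pixel resolution. A hedged sketch of how such shape triples map onto the TensorRT Python API; build_profile and the example values are illustrative, not taken from this file:

import tensorrt as trt

def build_profile(builder, config, input_names, shapes_min, shapes_opt, shapes_max):
    # One optimization profile covering the allowed dynamic range of every input.
    profile = builder.create_optimization_profile()
    for name, s_min, s_opt, s_max in zip(input_names, shapes_min, shapes_opt, shapes_max):
        profile.set_shape(name, s_min, s_opt, s_max)
    config.add_optimization_profile(profile)

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
# Latent-space "x" input only, as an example: (batch, in_channels, H // 8, W // 8)
build_profile(builder, config, ["x"],
              [(1, 4, 512 // 8, 512 // 8)],
              [(1, 4, 1024 // 8, 1024 // 8)],
              [(4, 4, 1536 // 8, 1536 // 8)])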