@@ -29,8 +29,9 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
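Note on the new guard above: when Accelerate drives DeepSpeed, the plugin's config commonly leaves `train_micro_batch_size_per_gpu` as `"auto"`, which can only be resolved if a dataloader is prepared at the right time; pinning it to the script's batch size up front sidesteps that. A minimal standalone sketch of the same pattern (the batch-size value here is an illustrative stand-in for `args.train_batch_size`):

```python
# Minimal sketch, not the training script itself: pin DeepSpeed's
# micro-batch size instead of leaving it as "auto" in the plugin config.
from accelerate import Accelerator, DistributedType
from accelerate.state import AcceleratorState

accelerator = Accelerator()  # assumes `accelerate launch` with a DeepSpeed config

train_batch_size = 4  # illustrative value
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    AcceleratorState().deepspeed_plugin.deepspeed_config[
        "train_micro_batch_size_per_gpu"
    ] = train_batch_size
```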
@@ -1438,17 +1442,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             modules_to_save = {}
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    model = unwrap_model(model)
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             FluxKontextPipeline.save_lora_weights(
                 output_dir,
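The switch to `isinstance(unwrap_model(model), ...)` is needed because under DeepSpeed the objects handed to the save hook are engine wrappers rather than the bare transformer/text-encoder modules, so the direct type check no longer matches. A sketch of the kind of unwrap helper such hooks rely on (the script's own helper may differ slightly; an `accelerator` instance is assumed to be in scope):

```python
# Hedged sketch: strip the Accelerate/DeepSpeed wrapper and any torch.compile
# wrapper so that isinstance checks see the original model class.
from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model(model):
    model = accelerator.unwrap_model(model)  # removes DDP/DeepSpeed wrappers
    model = model._orig_mod if is_compiled_module(model) else model  # removes torch.compile wrapper
    return model
```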
@@ -1461,15 +1468,25 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="text_encoder"
+            )
 
         lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
 
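Under DeepSpeed the `models` list handed to the load hook cannot be consumed the same way, so this branch instead rebuilds fresh base models, re-attaches the LoRA adapter, and lets the LoRA state dict be loaded into them below. Both hooks only take effect once they are registered with the Accelerator; in the diffusers LoRA scripts that registration typically looks like the following (a hedged sketch, since the registration lines sit outside this hunk):

```python
# Sketch: wire the custom hooks into accelerator.save_state()/load_state()
# so checkpointing routes through the LoRA-aware logic above.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```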
@@ -2069,7 +2086,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
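The widened condition matters because DeepSpeed ZeRO keeps a shard of the optimizer (and possibly parameter) state on every rank, so `accelerator.save_state` generally has to run on all processes rather than only the main one. A minimal sketch of the same guard in isolation (names like `checkpointing_steps` and `output_dir` are illustrative stand-ins):

```python
# Hedged sketch of the all-ranks checkpoint guard.
import os
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()

def maybe_save(global_step, checkpointing_steps=500, output_dir="outputs"):
    # With DeepSpeed, every rank must participate in save_state; otherwise
    # only rank 0's shard of the sharded state would be written.
    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
        if global_step % checkpointing_steps == 0:
            accelerator.save_state(os.path.join(output_dir, f"checkpoint-{global_step}"))
```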