diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 87e0d2c29e48..03c05a05e094 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -29,8 +29,9 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
@@ -1222,6 +1223,9 @@ def main(args):
         kwargs_handlers=[kwargs],
     )
 
+    if accelerator.distributed_type == DistributedType.DEEPSPEED:
+        AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
         accelerator.native_amp = False
@@ -1438,17 +1442,20 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_one_lora_layers_to_save = None
             modules_to_save = {}
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    model = unwrap_model(model)
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             FluxKontextPipeline.save_lora_weights(
                 output_dir,
@@ -1461,15 +1468,25 @@ def load_model_hook(models, input_dir):
         transformer_ = None
         text_encoder_one_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = FluxTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="text_encoder"
+            )
 
         lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
 
@@ -2069,7 +2086,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None: