diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 7d65b30659fb..df2fb1aecaaa 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -22,6 +22,7 @@
 import safetensors
 import torch
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -792,6 +793,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
             if hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
 
+        _maybe_remove_and_reapply_group_offloading(self)
+
     def enable_lora_hotswap(
         self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
     ) -> None:
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 3d4344bb86a9..10a5c8e9aa1d 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -28,6 +28,7 @@
     AutoencoderKL,
     UNet2DConditionModel,
 )
+from diffusers.hooks.group_offloading import apply_group_offloading
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 
@@ -2367,3 +2368,43 @@ def test_lora_loading_model_cpu_offload(self):
             output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
             self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    @require_torch_accelerator
+    def test_lora_group_offloading_delete_adapters(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+
+            components, _, _ = self.get_dummy_components()
+            pipe = self.pipeline_class(**components)
+            pipe.to(torch_device)
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+            # Enable Group Offloading (leaf_level)
+            apply_group_offloading(
+                denoiser,
+                onload_device=torch_device,
+                offload_device="cpu",
+                offload_type="leaf_level",
+            )
+
+            pipe.load_lora_weights(tmpdirname, adapter_name="default")
+            out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            # Delete the adapter
+            pipe.delete_adapters("default")
+            out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
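
For context, a minimal end-to-end sketch of the user-facing flow this patch fixes. It assumes a CUDA accelerator; the checkpoint and LoRA repository ids below are placeholders, not taken from this PR:

import torch
from diffusers import DiffusionPipeline
from diffusers.hooks.group_offloading import apply_group_offloading

# Placeholder checkpoint; any LoRA-compatible pipeline works the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.to("cuda")

# Group-offload the denoiser: parameters live on the CPU and are moved to the
# accelerator one leaf module at a time during the forward pass.
apply_group_offloading(
    pipe.unet,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

pipe.load_lora_weights("user/some-lora", adapter_name="default")  # placeholder repo id
_ = pipe("a prompt", num_inference_steps=2).images[0]

# Previously, deleting the adapter left the group-offloading hooks attached to
# the now-removed PEFT layers; delete_adapters() now calls
# _maybe_remove_and_reapply_group_offloading() so the hooks are rebuilt against
# the post-deletion module tree.
pipe.delete_adapters("default")
_ = pipe("a prompt", num_inference_steps=2).images[0]  # runs cleanly without the LoRA

The new leaf-level test above asserts the same behavior indirectly: after delete_adapters, the pipeline still runs and its output no longer matches the LoRA output.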