diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md index b18977720cf8..5a382c1c9423 100644 --- a/docs/source/en/tutorials/using_peft_for_inference.md +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -315,6 +315,8 @@ pipeline.load_lora_weights( > [!TIP] > Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example. +If you expect to varied resolutions during inference with this feature, then make sure set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details. + There are still scenarios where recompulation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs. ## Merge diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index dcc7ae16a44e..def81ecd648f 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1350,7 +1350,6 @@ def test_model_parallelism(self): new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) # Making sure part of the model will actually end up offloaded self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) - print(f" new_model.hf_device_map:{new_model.hf_device_map}") self.check_device_map_is_respected(new_model, new_model.hf_device_map) @@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin: """ + different_shapes_for_compilation = None + def tearDown(self): # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, # there will be recompilation errors, as torch caches the model when run in the same process. @@ -2056,11 +2057,13 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ - hotswap the second adapter - check that the outputs are correct - optionally compile the model + - optionally check if recompilations happen on different shapes Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is fine. """ + different_shapes = self.different_shapes_for_compilation # create 2 adapters with different ranks and alphas torch.manual_seed(0) init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -2110,19 +2113,30 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) if do_compile: - model = torch.compile(model, mode="reduce-overhead") + model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None) with torch.inference_mode(): - output0_after = model(**inputs_dict)["sample"] - assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) + # additionally check if dynamic compilation works. + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output0_after = model(**inputs_dict)["sample"] + assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) # hotswap the 2nd adapter model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) # we need to call forward to potentially trigger recompilation with torch.inference_mode(): - output1_after = model(**inputs_dict)["sample"] - assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output1_after = model(**inputs_dict)["sample"] + assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) # check error when not passing valid adapter name name = "does-not-exist" @@ -2240,3 +2254,23 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self): do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 ) assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) + @require_torch_version_greater("2.7.1") + def test_hotswapping_compile_on_different_shapes(self, rank0, rank1): + different_shapes_for_compilation = self.different_shapes_for_compilation + if different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic + # variable to represent input sizes that are the same. For more details, + # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + torch.fx.experimental._config.use_duck_shape = False + + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap( + do_compile=True, + rank0=rank0, + rank1=rank1, + target_modules0=target_modules, + ) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 4552b2e1f5cf..68b5c02bc0b0 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -186,6 +186,10 @@ def prepare_dummy_input(self, height, width): class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel + different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] def prepare_init_args_and_inputs_for_common(self): return FluxTransformerTests().prepare_init_args_and_inputs_for_common() + + def prepare_dummy_input(self, height, width): + return FluxTransformerTests().prepare_dummy_input(height=height, width=width)