From 4da07a7cd5ad5ee5b3204c2f94b46b49d5a0e9cd Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 28 Jun 2025 11:44:29 +0530
Subject: [PATCH 1/5] add resolution changes tests to hotswapping test suite.

---
 tests/models/test_modeling_common.py | 38 ++++++++++++++++++++++++++--
 1 file changed, 36 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index dcc7ae16a44e..7b65a448857a 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -2046,7 +2046,9 @@ def get_linear_module_name_other_than_attn(self, model):
         ]
         return linear_names[0]

-    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
+    def check_model_hotswap(
+        self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None
+    ):
         """
         Check that hotswapping works on a small unet.

@@ -2056,6 +2058,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         - hotswap the second adapter
         - check that the outputs are correct
         - optionally compile the model
+        - optionally check if recompilations happen on different shapes

         Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
@@ -2110,10 +2113,17 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

         if do_compile:
-            model = torch.compile(model, mode="reduce-overhead")
+            model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)

         with torch.inference_mode():
             output0_after = model(**inputs_dict)["sample"]
+
+            # additionally check if dynamic compilation works.
+            if different_resolutions is not None:
+                for height, width in self.different_shapes_for_compilation:
+                    new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                    _ = model(**new_inputs_dict)
+
         assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

         # hotswap the 2nd adapter
@@ -2122,6 +2132,12 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         # we need to call forward to potentially trigger recompilation
         with torch.inference_mode():
             output1_after = model(**inputs_dict)["sample"]
+
+            if different_resolutions is not None:
+                for height, width in self.different_shapes_for_compilation:
+                    new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+                    _ = model(**new_inputs_dict)
+
         assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

         # check error when not passing valid adapter name
@@ -2240,3 +2256,21 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self):
             do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
         )
         assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+    @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+    @require_torch_version_greater("2.7.1")
+    def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+        different_shapes_for_compilation = self.different_shapes_for_compilation
+        if different_shapes_for_compilation is None:
+            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        torch.fx.experimental._config.use_duck_shape = False
+
+        target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            self.check_model_hotswap(
+                do_compile=True,
+                rank0=rank0,
+                rank1=rank1,
+                target_modules0=target_modules,
+                different_resolutions=different_shapes_for_compilation,
+            )

From 9f1c83fb4c2401c5934947c4a240dccdcc1efd51 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 28 Jun 2025 12:59:19 +0530
Subject: [PATCH 2/5] fixes

---
 tests/models/test_modeling_common.py      | 17 ++++++++---------
 .../test_models_transformer_flux.py       |  4 ++++
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 7b65a448857a..ec2920850773 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1350,7 +1350,6 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-            print(f" new_model.hf_device_map:{new_model.hf_device_map}")

             self.check_device_map_is_respected(new_model, new_model.hf_device_map)

@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:

     """

+    different_shapes_for_compilation = None
+
     def tearDown(self):
         # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
         # there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2116,29 +2117,27 @@ def check_model_hotswap(
             model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)

         with torch.inference_mode():
-            output0_after = model(**inputs_dict)["sample"]
-
             # additionally check if dynamic compilation works.
             if different_resolutions is not None:
                 for height, width in self.different_shapes_for_compilation:
                     new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                     _ = model(**new_inputs_dict)
-
-        assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+            else:
+                output0_after = model(**inputs_dict)["sample"]
+        assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

         # hotswap the 2nd adapter
         model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

         # we need to call forward to potentially trigger recompilation
         with torch.inference_mode():
-            output1_after = model(**inputs_dict)["sample"]
-
             if different_resolutions is not None:
                 for height, width in self.different_shapes_for_compilation:
                     new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                     _ = model(**new_inputs_dict)
-
-        assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+            else:
+                output1_after = model(**inputs_dict)["sample"]
+        assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

         # check error when not passing valid adapter name
         name = "does-not-exist"
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 4552b2e1f5cf..68b5c02bc0b0 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -186,6 +186,10 @@ def prepare_dummy_input(self, height, width):

 class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

     def prepare_init_args_and_inputs_for_common(self):
         return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return FluxTransformerTests().prepare_dummy_input(height=height, width=width)

From 7fba82c5ac171ca86bd8c41a12061c08afd28583 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 28 Jun 2025 13:05:39 +0530
Subject: [PATCH 3/5] docs

---
 docs/source/en/tutorials/using_peft_for_inference.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index b18977720cf8..5a382c1c9423 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
 > [!TIP]
 > Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.

+If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
 There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
 
 ## Merge

From 2076a5390ffaa55766940623158c674583f7f9d9 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 30 Jun 2025 15:00:37 +0530
Subject: [PATCH 4/5] explain duck shapes

---
 tests/models/test_modeling_common.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index ec2920850773..a3ee2b07d8f1 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -2262,6 +2262,9 @@ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
         different_shapes_for_compilation = self.different_shapes_for_compilation
         if different_shapes_for_compilation is None:
             pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+        # Specifying `use_duck_shape=False` instructs the compiler not to reuse the same symbolic
+        # variable to represent input sizes that happen to be equal. For more details,
+        # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
         torch.fx.experimental._config.use_duck_shape = False

         target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

From 579fb768e1970c7af2bbb219ef3440a4060d493a Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 30 Jun 2025 19:17:03 +0530
Subject: [PATCH 5/5] fix

---
 tests/models/test_modeling_common.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index a3ee2b07d8f1..def81ecd648f 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -2047,9 +2047,7 @@ def get_linear_module_name_other_than_attn(self, model):
         ]
         return linear_names[0]

-    def check_model_hotswap(
-        self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None
-    ):
+    def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
         """
         Check that hotswapping works on a small unet.

@@ -2065,6 +2063,7 @@ def check_model_hotswap(
         fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
         fine.
         """
+        different_shapes = self.different_shapes_for_compilation
         # create 2 adapters with different ranks and alphas
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2114,12 +2113,12 @@ def check_model_hotswap(
         model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

         if do_compile:
-            model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None)
+            model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)

         with torch.inference_mode():
             # additionally check if dynamic compilation works.
-            if different_resolutions is not None:
-                for height, width in self.different_shapes_for_compilation:
+            if different_shapes is not None:
+                for height, width in different_shapes:
                     new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                     _ = model(**new_inputs_dict)
             else:
                 output0_after = model(**inputs_dict)["sample"]
@@ -2131,8 +2130,8 @@ def check_model_hotswap(
         # we need to call forward to potentially trigger recompilation
         with torch.inference_mode():
-            if different_resolutions is not None:
-                for height, width in self.different_shapes_for_compilation:
+            if different_shapes is not None:
+                for height, width in different_shapes:
                     new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
                     _ = model(**new_inputs_dict)
             else:
@@ -2274,5 +2273,4 @@ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
                 rank0=rank0,
                 rank1=rank1,
                 target_modules0=target_modules,
-                different_resolutions=different_shapes_for_compilation,
             )
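
Taken together, the test changes above reduce to one pattern: compile once with `dynamic=True`, disable duck shaping so that equal height/width values get independent symbolic sizes, and fail loudly if a later input shape forces a recompile. The snippet below is a minimal, self-contained sketch of that pattern; it substitutes a toy `Conv2d` for the diffusers model under test and uses the default compile mode instead of `"reduce-overhead"`.

```python
# Sketch only: a toy module stands in for the diffusers model used by the tests.
import torch
import torch._dynamo
import torch.fx.experimental._config

# Give equal-valued sizes (e.g. height == width) independent symbolic variables,
# so a later non-square input does not invalidate the compiled graph.
torch.fx.experimental._config.use_duck_shape = False

model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
compiled = torch.compile(model, dynamic=True)

shapes = [(8, 8), (8, 16), (16, 8)]  # stand-in for different_shapes_for_compilation
with torch._dynamo.config.patch(error_on_recompile=True):
    for height, width in shapes:
        x = torch.randn(1, 3, height, width)
        with torch.inference_mode():
            _ = compiled(x)  # raises if torch.compile has to recompile
```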
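At the user level, the documentation note added in PATCH 3/5 corresponds to roughly the following flow. This is a sketch only: the checkpoint and LoRA repository ids are placeholders, `target_rank=64` is an arbitrary choice, and `pipeline.unet` assumes a UNet-based pipeline (a transformer-based pipeline would compile `pipeline.transformer` instead).

```python
# Sketch only: repository ids below are placeholders, not real checkpoints.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "your/text-to-image-checkpoint", torch_dtype=torch.float16
).to("cuda")

# Prepare for hotswapping before loading the first LoRA; target_rank must cover
# the largest rank among the LoRAs that will be swapped in later.
pipeline.enable_lora_hotswap(target_rank=64)
pipeline.load_lora_weights("your/lora-a", adapter_name="default")

# dynamic=True avoids recompilation when the inference resolution changes.
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", dynamic=True)

image = pipeline("a prompt", height=768, width=768).images[0]

# Swap the second LoRA into the already compiled model, then change resolution.
pipeline.load_lora_weights("your/lora-b", hotswap=True, adapter_name="default")
image = pipeline("another prompt", height=1024, width=512).images[0]
```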