[tests] add test for hotswapping + compilation on resolution changes #11825

Merged: 8 commits, Jul 1, 2025
2 changes: 2 additions & 0 deletions docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
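Here is a minimal sketch of how these pieces can fit together, assuming an SDXL-style pipeline (so the compiled module is `pipeline.unet`); the checkpoint id, the `<lora-A>`/`<lora-B>` ids, and the rank passed to `enable_lora_hotswap` are placeholders:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# prepare for hotswapping before the first LoRA is loaded
pipeline.enable_lora_hotswap(target_rank=8)
pipeline.load_lora_weights("<lora-A>", adapter_name="default")

# dynamic=True keeps resolution changes from forcing a recompile
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", dynamic=True)

# error_on_recompile=True surfaces any unexpected recompilation as an error
with torch._dynamo.config.patch(error_on_recompile=True):
    pipeline("a prompt", height=768, width=768)
    pipeline.load_lora_weights("<lora-B>", hotswap=True, adapter_name="default")
    pipeline("a prompt", height=1024, width=1024)
```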

Review comment: I wonder if we should publicize `use_duck_shape = False` as well...
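If it were publicized, a docs note might look roughly like the sketch below; the tiny `nn.Linear` model and the input shapes are only illustrative stand-ins for a real denoiser and varying height/width:

```py
import torch
from torch import nn

# Disable duck shaping: dimensions that happen to be equal on the first call
# (here the two leading dims, think height == width) get their own symbolic
# variables instead of being merged into one, so a later call where they differ
# does not force a recompile.
torch.fx.experimental._config.use_duck_shape = False

model = torch.compile(nn.Linear(16, 16), dynamic=True)
with torch._dynamo.config.patch(error_on_recompile=True):
    model(torch.randn(4, 4, 16))  # both leading dims equal
    model(torch.randn(4, 8, 16))  # dims now differ; still no recompilation
```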


There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
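As a rough illustration of that ordering (a sketch: `<lora-wide>` is assumed to target a superset of the layers targeted by `<lora-narrow>`, and the checkpoint id is a placeholder):

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_lora_hotswap(target_rank=8)

# Load the adapter that targets the most layers first, then compile once...
pipeline.load_lora_weights("<lora-wide>", adapter_name="default")
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
pipeline("a prompt")

# ...and hotswap adapters that target a subset of those layers. The reverse order
# (narrow first, wide second) would run into the recompilation limitation above.
pipeline.load_lora_weights("<lora-narrow>", hotswap=True, adapter_name="default")
pipeline("a prompt")
```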

## Merge
46 changes: 40 additions & 6 deletions tests/models/test_modeling_common.py
@@ -1350,7 +1350,6 @@ def test_model_parallelism(self):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
print(f" new_model.hf_device_map:{new_model.hf_device_map}")
Review comment (Member Author): Unrelated but hopefully okay :-)


self.check_device_map_is_respected(new_model, new_model.hf_device_map)

@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:

"""

different_shapes_for_compilation = None

def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2056,11 +2057,13 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check that inference on different input shapes does not trigger recompilation

Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2110,19 +2113,30 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

if do_compile:
model = torch.compile(model, mode="reduce-overhead")
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)

with torch.inference_mode():
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)

# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2240,3 +2254,23 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self):
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)

@parameterized.expand([(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Setting `use_duck_shape=False` instructs the compiler not to reuse the same symbolic
# variable for input sizes that happen to be equal. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False

target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)
4 changes: 4 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -186,6 +186,10 @@ def prepare_dummy_input(self, height, width):

class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()

def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)