diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index dcb1b821c349..e6f16f323964 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1976,13 +1976,44 @@ def enable_input_require_grads(self):
         def make_inputs_require_grads(module, input, output):
             output.requires_grad_(True)

-        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+        hooks = []
+        seen_modules = set()
+
+        for module in self.modules():
+            if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
+                continue
+
+            input_embeddings = module.get_input_embeddings()
+
+            if input_embeddings is None:
+                continue
+
+            embedding_id = id(input_embeddings)
+            if embedding_id in seen_modules:
+                continue
+
+            seen_modules.add(embedding_id)
+            hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
+
+        self._require_grads_hooks = hooks
+        if hooks:
+            # for BC
+            self._require_grads_hook = hooks[0]

     def disable_input_require_grads(self):
         """
         Removes the `_require_grads_hook`.
         """
-        self._require_grads_hook.remove()
+        hooks = getattr(self, "_require_grads_hooks", None)
+        if not hooks:
+            return
+
+        for hook in hooks:
+            hook.remove()
+
+        self._require_grads_hooks = []
+        if hasattr(self, "_require_grads_hook"):
+            del self._require_grads_hook

     def get_encoder(self, modality: Optional[str] = None):
         """
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index d38112f1fbf8..e133058d794e 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -888,10 +888,12 @@ def __init__(self, config: Blip2QFormerConfig):
         self.post_init()

     def get_input_embeddings(self):
-        return self.embeddings.word_embeddings
+        # The Q-Former operates on embeddings provided by upstream modules (e.g. query tokens or text embeddings).
+        # It does not own input embeddings itself, so we return `None` to signal that there is nothing to update.
+        return None

     def set_input_embeddings(self, value):
-        self.embeddings.word_embeddings = value
+        raise NotImplementedError("Blip2QFormerModel does not own input embeddings and cannot set them.")

     def get_extended_attention_mask(
         self,
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 214dcef45081..1a394de39153 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -787,36 +787,6 @@ def __init__(self, config: Idefics2Config):

         self.post_init()

-    def enable_input_require_grads(self):
-        """
-        Enables the gradients for the input embeddings.
-
-        This is useful for lora when using gradient checkpointing.
-        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
-
-        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
-        """
-
-        def get_lowest_module(module):
-            if len(list(module.children())) == 0:
-                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
-                return module
-            else:
-                # Recursively call the function on each child module
-                return get_lowest_module(list(module.children())[0])
-
-        def make_inputs_require_grads(module, input, output):
-            output.requires_grad_(True)
-
-        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
-            make_inputs_require_grads
-        )
-
-    def disable_input_require_grads(self):
-        self._text_require_grads_hook.remove()
-        self._vision_require_grads_hook.remove()
-
     def get_input_embeddings(self):
         return self.text_model.get_input_embeddings()

@@ -1009,24 +979,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()

-    def enable_input_require_grads(self):
-        """
-        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
-        the model weights fixed.
-        """
-
-        def make_inputs_require_grads(module, input, output):
-            output.requires_grad_(True)
-
-        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
-            make_inputs_require_grads
-        )
-
-    def disable_input_require_grads(self):
-        self._text_require_grads_hook.remove()
-        self._vision_require_grads_hook.remove()
-
     def get_input_embeddings(self):
         return self.model.text_model.get_input_embeddings()

diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index 996f5573115c..9efb9287d9c6 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -530,38 +530,6 @@ def __init__(self, config: Idefics3Config):

         self.post_init()

-    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
-    def enable_input_require_grads(self):
-        """
-        Enables the gradients for the input embeddings.
-
-        This is useful for lora when using gradient checkpointing.
-        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
-
-        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
-        """
-
-        def get_lowest_module(module):
-            if len(list(module.children())) == 0:
-                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
-                return module
-            else:
-                # Recursively call the function on each child module
-                return get_lowest_module(list(module.children())[0])
-
-        def make_inputs_require_grads(module, input, output):
-            output.requires_grad_(True)
-
-        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
-            make_inputs_require_grads
-        )
-
-    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
-    def disable_input_require_grads(self):
-        self._text_require_grads_hook.remove()
-        self._vision_require_grads_hook.remove()
-
     # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
     def get_input_embeddings(self):
         return self.text_model.get_input_embeddings()

@@ -765,26 +733,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()

-    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
-    def enable_input_require_grads(self):
-        """
-        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
-        the model weights fixed.
-        """
-
-        def make_inputs_require_grads(module, input, output):
-            output.requires_grad_(True)
-
-        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
-            make_inputs_require_grads
-        )
-
-    # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
-    def disable_input_require_grads(self):
-        self._text_require_grads_hook.remove()
-        self._vision_require_grads_hook.remove()
-
     # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
     def get_input_embeddings(self):
         return self.model.text_model.get_input_embeddings()

diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 9b7070ce8bfc..ef5e846c83fc 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -1076,6 +1076,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
     config: Qwen2_5OmniVisionEncoderConfig
     _no_split_modules = ["Qwen2_5OmniVisionBlock"]
+    _input_embed_layer = "patch_embed"
     input_modalities = ("image", "video")

     def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 44c71c5eb0fa..2d179285d5f9 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1920,6 +1920,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
     config: Qwen2_5OmniVisionEncoderConfig
     input_modalities = ("image", "video")
     _no_split_modules = ["Qwen2_5OmniVisionBlock"]
+    _input_embed_layer = "patch_embed"

     def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
         super().__init__(config, *inputs, **kwargs)
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index c0dd5b983cd6..e5d1ca7e1943 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -307,6 +307,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
 class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
     config: Qwen2_5_VLVisionConfig
     _no_split_modules = ["Qwen2_5_VLVisionBlock"]
+    _input_embed_layer = "patch_embed"

     def __init__(self, config, *inputs, **kwargs) -> None:
         super().__init__(config, *inputs, **kwargs)
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index 1331c9e70e12..a8d88a73d9bc 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -179,6 +179,7 @@ class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
 class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
     config: Qwen2_5_VLVisionConfig
     _no_split_modules = ["Qwen2_5_VLVisionBlock"]
+    _input_embed_layer = "patch_embed"

     def __init__(self, config, *inputs, **kwargs) -> None:
         super().__init__(config, *inputs, **kwargs)
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 593a160ec799..cbe2848d8f7b 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -671,6 +671,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     config: Qwen2VLVisionConfig
     input_modalities = ("image", "video")
     _no_split_modules = ["Qwen2VLVisionBlock"]
+    _input_embed_layer = "patch_embed"

     def __init__(self, config) -> None:
         super().__init__(config)
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index 636dab11c9d0..9af6386e1e93 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -686,10 +686,10 @@ def _freeze_parameters(self):
         self._requires_grad = False

     def get_input_embeddings(self) -> nn.Module:
-        return self.conv1
+        return self.conv2d1

-    def set_input_embeddings(self, value: nn.Module):
-        self.conv1 = value
+    def set_input_embeddings(self, value):
+        self.conv2d1 = value

     def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
         # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
index d061aaf5e321..9883fef8abf0 100644
--- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py
@@ -1194,6 +1194,12 @@ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
         self.n_window_infer = self.config.n_window_infer
         self.conv_chunksize = self.config.conv_chunksize

+    def get_input_embeddings(self):
+        return self.conv2d1
+
+    def set_input_embeddings(self, value):
+        self.conv2d1 = value
+
     def forward(
         self,
         input_features,
diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py
index 02a080385aa6..e85048037c33 100644
--- a/src/transformers/models/smolvlm/modeling_smolvlm.py
+++ b/src/transformers/models/smolvlm/modeling_smolvlm.py
@@ -470,36 +470,6 @@ def __init__(self, config: SmolVLMConfig):

         self.post_init()

-    def enable_input_require_grads(self):
-        """
-        Enables the gradients for the input embeddings.
-
-        This is useful for lora when using gradient checkpointing.
-        c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
-
-        Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
-        """
-
-        def get_lowest_module(module):
-            if len(list(module.children())) == 0:
-                # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
-                return module
-            else:
-                # Recursively call the function on each child module
-                return get_lowest_module(list(module.children())[0])
-
-        def make_inputs_require_grads(module, input, output):
-            output.requires_grad_(True)
-
-        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
-            make_inputs_require_grads
-        )
-
-    def disable_input_require_grads(self):
-        self._text_require_grads_hook.remove()
-        self._vision_require_grads_hook.remove()
-
     def get_input_embeddings(self):
         return self.text_model.get_input_embeddings()

@@ -748,24 +718,6 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()

-    def enable_input_require_grads(self):
-        """
-        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
-        the model weights fixed.
-        """
-
-        def make_inputs_require_grads(module, input, output):
-            output.requires_grad_(True)
-
-        self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
-        self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
-            make_inputs_require_grads
-        )
-
-    def disable_input_require_grads(self):
-        self._text_require_grads_hook.remove()
-        self._vision_require_grads_hook.remove()
-
     def get_input_embeddings(self):
         return self.model.text_model.get_input_embeddings()

diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
index e7dca5068ebc..f3689acebcc5 100644
--- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
+++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -150,6 +150,13 @@ def __init__(self, config: TimmWrapperConfig):
         self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
         self.post_init()

+    def get_input_embeddings(self):
+        # Vision backbones from timm do not expose token embeddings, so there is nothing to return.
+        return None
+
+    def set_input_embeddings(self, value):
+        raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")
+
     @auto_docstring
     def forward(
         self,
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index 3ac771067dd2..2570b34cf0eb 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -478,23 +478,34 @@ def test_inputs_embeds(self):
             with torch.no_grad():
                 model(**inputs)[0]

-    @unittest.skip("Bart no longer always uses self.shared so not working.")
     def test_input_embeddings_support_forward_hook(self):
         # Make sure that registering hooks on the input embeddings are indeed called
         # in forward. This is necessary for gradient checkpointing in PEFT, see also #41821.
+        # For BART with tied embeddings, encoder and decoder have separate embedding modules,
+        # so we need to check that hooks on those modules are called during forward.
         config, inputs_dict = self.model_tester.prepare_config_and_inputs()
         for model_class in self.all_model_classes:
             model = model_class(config)
             model.to(torch_device)
             model.eval()

-            hook = unittest.mock.MagicMock(return_value=None)
-            model.get_input_embeddings().register_forward_hook(hook)
+            hooks = []
+            base_model = model.model if hasattr(model, "model") else model
+
+            if hasattr(base_model, "encoder") and hasattr(base_model.encoder, "embed_tokens"):
+                hook = unittest.mock.MagicMock(return_value=None)
+                base_model.encoder.embed_tokens.register_forward_hook(hook)
+                hooks.append(hook)
+            if hasattr(base_model, "decoder") and hasattr(base_model.decoder, "embed_tokens"):
+                hook = unittest.mock.MagicMock(return_value=None)
+                base_model.decoder.embed_tokens.register_forward_hook(hook)
+                hooks.append(hook)

             inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
             model(**inputs)

-            self.assertGreater(hook.call_count, 0)
+            total_calls = sum(hook.call_count for hook in hooks)
+            self.assertGreater(total_calls, 0, f"Hooks on embeddings were not called for {model_class.__name__}")

     @require_torch_fp16
     def test_generate_fp16(self):
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index cdbe65bba5cf..5f90620ca703 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -362,6 +362,62 @@ def test_sdpa_can_dispatch_on_flash(self):
     def test_multi_gpu_data_parallel_forward(self):
         pass

+    def test_enable_input_require_grads_with_gradient_checkpointing(self):
+        if not self.model_tester.is_training:
+            self.skipTest(reason="ModelTester not in training mode")
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.use_cache = False
+        config.return_dict = True
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            model = model_class(config)
+            model.to(torch_device)
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
+            model.enable_input_require_grads()
+            model.train()
+
+            for parameter in model.parameters():
+                parameter.requires_grad = False
+
+            vision_module = None
+            if hasattr(model, "visual"):
+                vision_module = model.visual
+            elif hasattr(model, "model") and hasattr(model.model, "visual"):
+                vision_module = model.model.visual
+
+            if vision_module is None:
+                continue
+
+            target_linear = vision_module.blocks[0].attn.qkv
+            target_linear.weight.requires_grad = True
+            if target_linear.bias is not None:
+                target_linear.bias.requires_grad = True
+
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            outputs = model(**inputs)
+
+            if hasattr(outputs, "loss") and outputs.loss is not None:
+                loss = outputs.loss
+            else:
+                logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
+                loss = logits.sum()
+
+            loss.backward()
+
+            self.assertIsNotNone(
+                target_linear.weight.grad,
+                f"qkv weights should receive gradients when enable_input_require_grads is used with gradient checkpointing. Model: {model_class.__name__}",
+            )
+            self.assertGreater(
+                target_linear.weight.grad.abs().sum().item(),
+                0,
+                f"qkv weights should have non-zero gradients when enable_input_require_grads is used with gradient checkpointing. Model: {model_class.__name__}",
+            )
+

 @require_torch
 class Qwen2VLIntegrationTest(unittest.TestCase):
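
Below the patch, a minimal usage sketch of the behaviour this change targets: with the base weights frozen PEFT-style and reentrant gradient checkpointing enabled, `enable_input_require_grads()` registers forward hooks on every input embedding so gradients still reach any parameter left trainable. The checkpoint name and the parameter path are illustrative assumptions, not something the diff prescribes.

```python
# Minimal sketch (not part of the patch): exercising the reworked hooks the way the
# qwen2_vl test above does, but on a generic causal LM. The checkpoint name and the
# parameter path below are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Freeze everything, as a PEFT-style adapter setup would, then unfreeze one projection
# so we can verify that gradients actually reach it through the checkpointed blocks.
for param in model.parameters():
    param.requires_grad = False
target = model.model.layers[0].self_attn.q_proj
target.weight.requires_grad = True

# With reentrant checkpointing, inputs that do not require grad would otherwise break the
# backward pass; the forward hooks registered by enable_input_require_grads() mark the
# embedding outputs as requiring grad so the graph stays connected.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()
model.train()

input_ids = torch.tensor([[1, 2, 3, 4]])
outputs = model(input_ids=input_ids, labels=input_ids)
outputs.loss.backward()

assert target.weight.grad is not None and target.weight.grad.abs().sum() > 0

# Remove the hooks again once they are no longer needed.
model.disable_input_require_grads()
```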