From 61679f844689772a293ed86040e0608326c27760 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Mon, 3 Nov 2025 15:24:00 +0100 Subject: [PATCH 1/7] attempt to fix gradients --- src/transformers/modeling_utils.py | 31 +++++++++++ .../models/qwen2_vl/test_modeling_qwen2_vl.py | 52 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 794c238c0fa7..15171200cc28 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1641,6 +1641,12 @@ def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings +def _get_lowest_module(module): + if len(list(module.children())) == 0: + return module + return _get_lowest_module(list(module.children())[0]) + + class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. @@ -2575,11 +2581,36 @@ def make_inputs_require_grads(module, input, output): self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + vision_module_attributes = [ + "vision_tower", + "vision_model", + "visual", + "vision_encoder", + "image_encoder", + ] + + vision_module = None + for attribute_name in vision_module_attributes: + if hasattr(self, attribute_name): + vision_module = getattr(self, attribute_name) + elif hasattr(self, "model") and hasattr(self.model, attribute_name): + vision_module = getattr(self.model, attribute_name) + + if vision_module is not None: + break + + if vision_module is not None: + self._vision_require_grads_hook = _get_lowest_module(vision_module).register_forward_hook( + make_inputs_require_grads + ) + def disable_input_require_grads(self): """ Removes the `_require_grads_hook`. 
""" self._require_grads_hook.remove() + if hasattr(self, "_vision_require_grads_hook"): + self._vision_require_grads_hook.remove() def get_decoder(self): """ diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index e06c2872ed5d..e472dc25b392 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -397,6 +397,58 @@ def test_sdpa_can_dispatch_on_flash(self): def test_multi_gpu_data_parallel_forward(self): pass + def test_enable_input_require_grads_with_gradient_checkpointing(self): + if not self.model_tester.is_training: + self.skipTest(reason="ModelTester not in training mode") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + model = model_class(config) + model.to(torch_device) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + model.enable_input_require_grads() + model.train() + + for parameter in model.parameters(): + parameter.requires_grad = False + + vision_module = None + if hasattr(model, "visual"): + vision_module = model.visual + elif hasattr(model, "model") and hasattr(model.model, "visual"): + vision_module = model.model.visual + + if vision_module is not None: + for parameter in vision_module.parameters(): + parameter.requires_grad = True + + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + outputs = model(**inputs) + + if hasattr(outputs, "loss") and outputs.loss is not None: + loss = outputs.loss + else: + logits = outputs.logits if hasattr(outputs, "logits") else outputs[0] + loss = logits.sum() + + loss.backward() + + if vision_module is not None: + vision_params_with_grad = [ + name for name, param in vision_module.named_parameters() if param.grad is not None + ] + self.assertGreater( + len(vision_params_with_grad), + 0, + f"Vision module parameters should have gradients when enable_input_require_grads is used with gradient checkpointing. 
Model: {model_class.__name__}", + ) + @require_torch class Qwen2VLIntegrationTest(unittest.TestCase): From c1f032988301ccc3ea5c0a5719cb96c258136114 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Fri, 14 Nov 2025 11:18:51 +0100 Subject: [PATCH 2/7] Improve tests, use PreTrainedModel hooks, cleanup --- src/transformers/modeling_utils.py | 66 +++++++++---------- .../models/idefics2/modeling_idefics2.py | 48 -------------- .../models/idefics3/modeling_idefics3.py | 52 --------------- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 + .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 1 + .../models/qwen2_vl/modeling_qwen2_vl.py | 1 + .../models/smolvlm/modeling_smolvlm.py | 48 -------------- .../models/qwen2_vl/test_modeling_qwen2_vl.py | 28 ++++---- 8 files changed, 52 insertions(+), 193 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c9571f7f5f6b..6bf5ed0f9730 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1299,12 +1299,6 @@ def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings -def _get_lowest_module(module): - if len(list(module.children())) == 0: - return module - return _get_lowest_module(list(module.children())[0]) - - class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. @@ -2220,38 +2214,44 @@ def enable_input_require_grads(self): def make_inputs_require_grads(module, input, output): output.requires_grad_(True) - self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - - vision_module_attributes = [ - "vision_tower", - "vision_model", - "visual", - "vision_encoder", - "image_encoder", - ] - - vision_module = None - for attribute_name in vision_module_attributes: - if hasattr(self, attribute_name): - vision_module = getattr(self, attribute_name) - elif hasattr(self, "model") and hasattr(self.model, attribute_name): - vision_module = getattr(self.model, attribute_name) - - if vision_module is not None: - break - - if vision_module is not None: - self._vision_require_grads_hook = _get_lowest_module(vision_module).register_forward_hook( - make_inputs_require_grads - ) + hooks = [] + seen_modules = set() + + for module in self.modules(): + if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")): + continue + + input_embeddings = module.get_input_embeddings() + + if input_embeddings is None: + continue + + embedding_id = id(input_embeddings) + if embedding_id in seen_modules: + continue + + seen_modules.add(embedding_id) + hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads)) + + self._require_grads_hooks = hooks + if hooks: + # for BC + self._require_grads_hook = hooks[0] def disable_input_require_grads(self): """ Removes the `_require_grads_hook`. 
""" - self._require_grads_hook.remove() - if hasattr(self, "_vision_require_grads_hook"): - self._vision_require_grads_hook.remove() + hooks = getattr(self, "_require_grads_hooks", None) + if not hooks: + return + + for hook in hooks: + hook.remove() + + self._require_grads_hooks = [] + if hasattr(self, "_require_grads_hook"): + del self._require_grads_hook def get_decoder(self): """ diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 2caaf2ab2706..136397963ad2 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -802,36 +802,6 @@ def __init__(self, config: Idefics2Config): self.post_init() - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. - - This is useful for lora when using gradient checkpointing. - c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 - - Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. - """ - - def get_lowest_module(module): - if len(list(module.children())) == 0: - # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) - return module - else: - # Recursively call the function on each child module - return get_lowest_module(list(module.children())[0]) - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - def get_input_embeddings(self): return self.text_model.get_input_embeddings() @@ -1024,24 +994,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping - the model weights fixed. - """ - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - def get_input_embeddings(self): return self.model.text_model.get_input_embeddings() diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 6a57af9d49d8..f982db92cc22 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -549,38 +549,6 @@ def __init__(self, config: Idefics3Config): self.post_init() - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. - - This is useful for lora when using gradient checkpointing. - c.f. 
https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 - - Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. - """ - - def get_lowest_module(module): - if len(list(module.children())) == 0: - # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) - return module - else: - # Recursively call the function on each child module - return get_lowest_module(list(module.children())[0]) - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( - make_inputs_require_grads - ) - - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings def get_input_embeddings(self): return self.text_model.get_input_embeddings() @@ -784,26 +752,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping - the model weights fixed. - """ - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( - make_inputs_require_grads - ) - - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings def get_input_embeddings(self): return self.model.text_model.get_input_embeddings() diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 1a24d18939bb..ff6f75746d58 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -307,6 +307,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): config: Qwen2_5_VLVisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] + _input_embed_layer = "patch_embed" def __init__(self, config, *inputs, **kwargs) -> None: super().__init__(config, *inputs, **kwargs) diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index dc88dd02ee73..c7a5f7ebf15a 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -179,6 +179,7 @@ class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel): class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): config: 
Qwen2_5_VLVisionConfig _no_split_modules = ["Qwen2_5_VLVisionBlock"] + _input_embed_layer = "patch_embed" def __init__(self, config, *inputs, **kwargs) -> None: super().__init__(config, *inputs, **kwargs) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index c1b52ff75f9f..085d2bad37f6 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -671,6 +671,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): config: Qwen2VLVisionConfig input_modalities = ["image", "video"] _no_split_modules = ["Qwen2VLVisionBlock"] + _input_embed_layer = "patch_embed" def __init__(self, config) -> None: super().__init__(config) diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 625e616b5989..a3ae0d151acc 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -509,36 +509,6 @@ def __init__(self, config: SmolVLMConfig): self.post_init() - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. - - This is useful for lora when using gradient checkpointing. - c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 - - Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. - """ - - def get_lowest_module(module): - if len(list(module.children())) == 0: - # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) - return module - else: - # Recursively call the function on each child module - return get_lowest_module(list(module.children())[0]) - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - def get_input_embeddings(self): return self.text_model.get_input_embeddings() @@ -787,24 +757,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def enable_input_require_grads(self): - """ - Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping - the model weights fixed. 
- """ - - def make_inputs_require_grads(module, input, output): - output.requires_grad_(True) - - self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) - self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( - make_inputs_require_grads - ) - - def disable_input_require_grads(self): - self._text_require_grads_hook.remove() - self._vision_require_grads_hook.remove() - def get_input_embeddings(self): return self.model.text_model.get_input_embeddings() diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index e472dc25b392..21ab16b2dbdd 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -424,9 +424,13 @@ def test_enable_input_require_grads_with_gradient_checkpointing(self): elif hasattr(model, "model") and hasattr(model.model, "visual"): vision_module = model.model.visual - if vision_module is not None: - for parameter in vision_module.parameters(): - parameter.requires_grad = True + if vision_module is None: + continue + + target_linear = vision_module.blocks[0].attn.qkv + target_linear.weight.requires_grad = True + if target_linear.bias is not None: + target_linear.bias.requires_grad = True inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) outputs = model(**inputs) @@ -439,15 +443,15 @@ def test_enable_input_require_grads_with_gradient_checkpointing(self): loss.backward() - if vision_module is not None: - vision_params_with_grad = [ - name for name, param in vision_module.named_parameters() if param.grad is not None - ] - self.assertGreater( - len(vision_params_with_grad), - 0, - f"Vision module parameters should have gradients when enable_input_require_grads is used with gradient checkpointing. Model: {model_class.__name__}", - ) + self.assertIsNotNone( + target_linear.weight.grad, + f"qkv weights should receive gradients when enable_input_require_grads is used with gradient checkpointing. Model: {model_class.__name__}", + ) + self.assertGreater( + target_linear.weight.grad.abs().sum().item(), + 0, + f"qkv weights should have non-zero gradients when enable_input_require_grads is used with gradient checkpointing. 
Model: {model_class.__name__}", + ) @require_torch From 3e160ed8cf393dd9595bc2ae2372efc174c48cba Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Fri, 14 Nov 2025 11:59:22 +0100 Subject: [PATCH 3/7] missing patch_embed --- src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py | 1 + src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 77bc48a1e19d..504ade96f065 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1075,6 +1075,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): config: Qwen2_5OmniVisionEncoderConfig _no_split_modules = ["Qwen2_5OmniVisionBlock"] + _input_embed_layer = "patch_embed" input_modalities = ["image", "video"] def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 673da8201fed..d0ee964597d6 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1935,6 +1935,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): config: Qwen2_5OmniVisionEncoderConfig input_modalities = ["image", "video"] _no_split_modules = ["Qwen2_5OmniVisionBlock"] + _input_embed_layer = "patch_embed" def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: super().__init__(config, *inputs, **kwargs) From b70f0e10157a1148f33f92b3912e0e69f1ef1a3d Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Mon, 17 Nov 2025 10:46:44 +0100 Subject: [PATCH 4/7] fix arg name --- tests/models/glm46v/test_processor_glm46v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/glm46v/test_processor_glm46v.py b/tests/models/glm46v/test_processor_glm46v.py index 4e9fb4449d67..268f20d89c89 100644 --- a/tests/models/glm46v/test_processor_glm46v.py +++ b/tests/models/glm46v/test_processor_glm46v.py @@ -215,7 +215,7 @@ def test_apply_chat_template_video_frame_sampling(self): add_generation_prompt=True, tokenize=True, return_dict=True, - video_fps=video_fps, + fps=video_fps, ) self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8) From 17c0db4d552c47ce67121244e6a5a021c8afd7c9 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Mon, 17 Nov 2025 11:18:49 +0100 Subject: [PATCH 5/7] local revert --- tests/models/glm46v/test_processor_glm46v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/glm46v/test_processor_glm46v.py b/tests/models/glm46v/test_processor_glm46v.py index 268f20d89c89..4e9fb4449d67 100644 --- a/tests/models/glm46v/test_processor_glm46v.py +++ b/tests/models/glm46v/test_processor_glm46v.py @@ -215,7 +215,7 @@ def test_apply_chat_template_video_frame_sampling(self): add_generation_prompt=True, tokenize=True, return_dict=True, - fps=video_fps, + video_fps=video_fps, ) self.assertTrue(self.videos_input_name in out_dict_with_video) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 8) From 5b07e4375d78400974083a1998d81fb4c60f4105 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo Date: Thu, 27 Nov 
2025 11:59:22 +0100
Subject: [PATCH 6/7] adapt BART test

---
 tests/models/bart/test_modeling_bart.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index 3ac771067dd2..2570b34cf0eb 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -478,23 +478,34 @@ def test_inputs_embeds(self):
         with torch.no_grad():
             model(**inputs)[0]

-    @unittest.skip("Bart no longer always uses self.shared so not working.")
     def test_input_embeddings_support_forward_hook(self):
         # Make sure that registering hooks on the input embeddings are indeed called
         # in forward. This is necessary for gradient checkpointing in PEFT, see also #41821.
+        # Even when BART ties the embedding weights, the encoder and decoder wrap them in
+        # separate modules, so we check that hooks on each of those modules fire during forward.
         config, inputs_dict = self.model_tester.prepare_config_and_inputs()

         for model_class in self.all_model_classes:
             model = model_class(config)
             model.to(torch_device)
             model.eval()

-            hook = unittest.mock.MagicMock(return_value=None)
-            model.get_input_embeddings().register_forward_hook(hook)
+            hooks = []
+            base_model = model.model if hasattr(model, "model") else model
+
+            if hasattr(base_model, "encoder") and hasattr(base_model.encoder, "embed_tokens"):
+                hook = unittest.mock.MagicMock(return_value=None)
+                base_model.encoder.embed_tokens.register_forward_hook(hook)
+                hooks.append(hook)
+            if hasattr(base_model, "decoder") and hasattr(base_model.decoder, "embed_tokens"):
+                hook = unittest.mock.MagicMock(return_value=None)
+                base_model.decoder.embed_tokens.register_forward_hook(hook)
+                hooks.append(hook)

             inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
             model(**inputs)

-            self.assertGreater(hook.call_count, 0)
+            total_calls = sum(hook.call_count for hook in hooks)
+            self.assertGreater(total_calls, 0, f"Hooks on embeddings were not called for {model_class.__name__}")

     @require_torch_fp16
     def test_generate_fp16(self):

From 1cb35ff8bd5a87fb240f2e04cb2bf45b01dd27f3 Mon Sep 17 00:00:00 2001
From: Pablo Montalvo
Date: Thu, 27 Nov 2025 16:10:47 +0100
Subject: [PATCH 7/7] fix lingering failures

---
 src/transformers/models/blip_2/modeling_blip_2.py    | 6 ++++--
 .../models/qwen3_omni_moe/modeling_qwen3_omni_moe.py | 6 +++---
 .../models/qwen3_omni_moe/modular_qwen3_omni_moe.py  | 6 ++++++
 .../models/timm_wrapper/modeling_timm_wrapper.py     | 7 +++++++
 4 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index d38112f1fbf8..e133058d794e 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -888,10 +888,12 @@ def __init__(self, config: Blip2QFormerConfig):
         self.post_init()

     def get_input_embeddings(self):
-        return self.embeddings.word_embeddings
+        # The Q-Former operates on embeddings provided by upstream modules (e.g. query tokens or text embeddings).
+        # It does not own input embeddings itself, so we return `None` to signal that there is nothing to update.
+ return None def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value + raise NotImplementedError("Blip2QFormerModel does not own input embeddings and cannot set them.") def get_extended_attention_mask( self, diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 69b6be3f4fcc..fea9ba4dea2a 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -686,10 +686,10 @@ def _freeze_parameters(self): self._requires_grad = False def get_input_embeddings(self) -> nn.Module: - return self.conv1 + return self.conv2d1 - def set_input_embeddings(self, value: nn.Module): - self.conv1 = value + def set_input_embeddings(self, value): + self.conv2d1 = value def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen` diff --git a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py index cd9f94681b9d..b0597d50d5c9 100644 --- a/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py @@ -1094,6 +1094,12 @@ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig): self.n_window_infer = self.config.n_window_infer self.conv_chunksize = self.config.conv_chunksize + def get_input_embeddings(self): + return self.conv2d1 + + def set_input_embeddings(self, value): + self.conv2d1 = value + def forward( self, input_features, diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index aa541e0b15fb..26688eed2db5 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -168,6 +168,13 @@ def __init__(self, config: TimmWrapperConfig): self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs) self.post_init() + def get_input_embeddings(self): + # Vision backbones from timm do not expose token embeddings, so there is nothing to return. + return None + + def set_input_embeddings(self, value): + raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.") + @auto_docstring def forward( self,
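
Note for reviewers, outside the patch series: the end-to-end behavior that PATCH 2/7 is meant to
guarantee can be sanity-checked with a short script. A minimal sketch follows, assuming a locally
available Qwen2-VL checkpoint (the model id below is only an example) and enough memory for one
full-precision forward/backward pass; it mirrors the qwen2_vl test above rather than defining new API.

    from PIL import Image
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    model_id = "Qwen/Qwen2-VL-2B-Instruct"  # example checkpoint, substitute your own
    processor = AutoProcessor.from_pretrained(model_id)
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id)

    # Freeze everything, as a LoRA-style adapter setup would, then unfreeze only the
    # vision tower so the backward pass has trainable parameters to reach.
    vision = model.visual if hasattr(model, "visual") else model.model.visual
    for param in model.parameters():
        param.requires_grad = False
    for param in vision.parameters():
        param.requires_grad = True

    # With use_reentrant=True, a checkpointed segment whose inputs do not require grad
    # produces outputs detached from the graph, so gradients never reach it. The hooks
    # registered by enable_input_require_grads() (now on the input embeddings of every
    # PreTrainedModel submodule, e.g. the vision tower's patch_embed) prevent that.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
    model.enable_input_require_grads()
    model.train()

    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    image = Image.new("RGB", (224, 224), color="white")  # dummy image, enough to exercise the path
    inputs = processor(text=[prompt], images=[image], return_tensors="pt")

    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()

    assert any(p.grad is not None for p in vision.parameters()), "vision tower received no gradients"

    model.disable_input_require_grads()  # removes every hook registered above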