Skip to content

Commit e2f08ea

Browse files
molbap and ArthurZucker authored
Attempt to fix VLM gradient enabling (#41993)
* attempt to fix gradients * Improve tests, use PreTrainedModel hooks, cleanup * missing patch_embed * fix arg name * local revert * adapt BART test * lingering fails --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent afd039d commit e2f08ea

File tree

15 files changed

+129
-159
lines changed

15 files changed

+129
-159
lines changed

src/transformers/modeling_utils.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1976,13 +1976,44 @@ def enable_input_require_grads(self):
19761976
def make_inputs_require_grads(module, input, output):
19771977
output.requires_grad_(True)
19781978

1979-
self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
1979+
hooks = []
1980+
seen_modules = set()
1981+
1982+
for module in self.modules():
1983+
if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
1984+
continue
1985+
1986+
input_embeddings = module.get_input_embeddings()
1987+
1988+
if input_embeddings is None:
1989+
continue
1990+
1991+
embedding_id = id(input_embeddings)
1992+
if embedding_id in seen_modules:
1993+
continue
1994+
1995+
seen_modules.add(embedding_id)
1996+
hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))
1997+
1998+
self._require_grads_hooks = hooks
1999+
if hooks:
2000+
# for BC
2001+
self._require_grads_hook = hooks[0]
19802002

19812003
def disable_input_require_grads(self):
19822004
"""
19832005
Removes the `_require_grads_hook`.
19842006
"""
1985-
self._require_grads_hook.remove()
2007+
hooks = getattr(self, "_require_grads_hooks", None)
2008+
if not hooks:
2009+
return
2010+
2011+
for hook in hooks:
2012+
hook.remove()
2013+
2014+
self._require_grads_hooks = []
2015+
if hasattr(self, "_require_grads_hook"):
2016+
del self._require_grads_hook
19862017

19872018
def get_encoder(self, modality: Optional[str] = None):
19882019
"""

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,10 +888,12 @@ def __init__(self, config: Blip2QFormerConfig):
888888
self.post_init()
889889

890890
def get_input_embeddings(self):
891-
return self.embeddings.word_embeddings
891+
# The Q-Former operates on embeddings provided by upstream modules (e.g. query tokens or text embeddings).
892+
# It does not own input embeddings itself, so we return `None` to signal that there is nothing to update.
893+
return None
892894

893895
def set_input_embeddings(self, value):
894-
self.embeddings.word_embeddings = value
896+
raise NotImplementedError("Blip2QFormerModel does not own input embeddings and cannot set them.")
895897

896898
def get_extended_attention_mask(
897899
self,

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -787,36 +787,6 @@ def __init__(self, config: Idefics2Config):
787787

788788
self.post_init()
789789

790-
def enable_input_require_grads(self):
791-
"""
792-
Enables the gradients for the input embeddings.
793-
794-
This is useful for lora when using gradient checkpointing.
795-
c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
796-
797-
Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
798-
"""
799-
800-
def get_lowest_module(module):
801-
if len(list(module.children())) == 0:
802-
# If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
803-
return module
804-
else:
805-
# Recursively call the function on each child module
806-
return get_lowest_module(list(module.children())[0])
807-
808-
def make_inputs_require_grads(module, input, output):
809-
output.requires_grad_(True)
810-
811-
self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
812-
self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
813-
make_inputs_require_grads
814-
)
815-
816-
def disable_input_require_grads(self):
817-
self._text_require_grads_hook.remove()
818-
self._vision_require_grads_hook.remove()
819-
820790
def get_input_embeddings(self):
821791
return self.text_model.get_input_embeddings()
822792

@@ -1009,24 +979,6 @@ def __init__(self, config):
1009979
# Initialize weights and apply final processing
1010980
self.post_init()
1011981

1012-
def enable_input_require_grads(self):
1013-
"""
1014-
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
1015-
the model weights fixed.
1016-
"""
1017-
1018-
def make_inputs_require_grads(module, input, output):
1019-
output.requires_grad_(True)
1020-
1021-
self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
1022-
self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
1023-
make_inputs_require_grads
1024-
)
1025-
1026-
def disable_input_require_grads(self):
1027-
self._text_require_grads_hook.remove()
1028-
self._vision_require_grads_hook.remove()
1029-
1030982
def get_input_embeddings(self):
1031983
return self.model.text_model.get_input_embeddings()
1032984

src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -530,38 +530,6 @@ def __init__(self, config: Idefics3Config):
530530

531531
self.post_init()
532532

533-
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
534-
def enable_input_require_grads(self):
535-
"""
536-
Enables the gradients for the input embeddings.
537-
538-
This is useful for lora when using gradient checkpointing.
539-
c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
540-
541-
Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
542-
"""
543-
544-
def get_lowest_module(module):
545-
if len(list(module.children())) == 0:
546-
# If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
547-
return module
548-
else:
549-
# Recursively call the function on each child module
550-
return get_lowest_module(list(module.children())[0])
551-
552-
def make_inputs_require_grads(module, input, output):
553-
output.requires_grad_(True)
554-
555-
self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
556-
self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
557-
make_inputs_require_grads
558-
)
559-
560-
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
561-
def disable_input_require_grads(self):
562-
self._text_require_grads_hook.remove()
563-
self._vision_require_grads_hook.remove()
564-
565533
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
566534
def get_input_embeddings(self):
567535
return self.text_model.get_input_embeddings()
@@ -765,26 +733,6 @@ def __init__(self, config):
765733
# Initialize weights and apply final processing
766734
self.post_init()
767735

768-
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
769-
def enable_input_require_grads(self):
770-
"""
771-
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
772-
the model weights fixed.
773-
"""
774-
775-
def make_inputs_require_grads(module, input, output):
776-
output.requires_grad_(True)
777-
778-
self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
779-
self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
780-
make_inputs_require_grads
781-
)
782-
783-
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
784-
def disable_input_require_grads(self):
785-
self._text_require_grads_hook.remove()
786-
self._vision_require_grads_hook.remove()
787-
788736
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
789737
def get_input_embeddings(self):
790738
return self.model.text_model.get_input_embeddings()

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
10761076
class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
10771077
config: Qwen2_5OmniVisionEncoderConfig
10781078
_no_split_modules = ["Qwen2_5OmniVisionBlock"]
1079+
_input_embed_layer = "patch_embed"
10791080
input_modalities = ("image", "video")
10801081

10811082
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,6 +1920,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
19201920
config: Qwen2_5OmniVisionEncoderConfig
19211921
input_modalities = ("image", "video")
19221922
_no_split_modules = ["Qwen2_5OmniVisionBlock"]
1923+
_input_embed_layer = "patch_embed"
19231924

19241925
def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
19251926
super().__init__(config, *inputs, **kwargs)

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
307307
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
308308
config: Qwen2_5_VLVisionConfig
309309
_no_split_modules = ["Qwen2_5_VLVisionBlock"]
310+
_input_embed_layer = "patch_embed"
310311

311312
def __init__(self, config, *inputs, **kwargs) -> None:
312313
super().__init__(config, *inputs, **kwargs)

src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
179179
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
180180
config: Qwen2_5_VLVisionConfig
181181
_no_split_modules = ["Qwen2_5_VLVisionBlock"]
182+
_input_embed_layer = "patch_embed"
182183

183184
def __init__(self, config, *inputs, **kwargs) -> None:
184185
super().__init__(config, *inputs, **kwargs)

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
671671
config: Qwen2VLVisionConfig
672672
input_modalities = ("image", "video")
673673
_no_split_modules = ["Qwen2VLVisionBlock"]
674+
_input_embed_layer = "patch_embed"
674675

675676
def __init__(self, config) -> None:
676677
super().__init__(config)

src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -686,10 +686,10 @@ def _freeze_parameters(self):
686686
self._requires_grad = False
687687

688688
def get_input_embeddings(self) -> nn.Module:
689-
return self.conv1
689+
return self.conv2d1
690690

691-
def set_input_embeddings(self, value: nn.Module):
692-
self.conv1 = value
691+
def set_input_embeddings(self, value):
692+
self.conv2d1 = value
693693

694694
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
695695
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`

0 commit comments

Comments
 (0)