35 changes: 33 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1976,13 +1976,44 @@ def enable_input_require_grads(self):
def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
hooks = []
seen_modules = set()

for module in self.modules():
if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")):
continue

input_embeddings = module.get_input_embeddings()

if input_embeddings is None:
continue

embedding_id = id(input_embeddings)
if embedding_id in seen_modules:
continue

seen_modules.add(embedding_id)
hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads))

Comment on lines +1982 to +1997
Member


super clean!

self._require_grads_hooks = hooks
if hooks:
# for BC
self._require_grads_hook = hooks[0]
Comment on lines +1999 to +2001

Member

Aren't we ignoring all hooks except for the first one in this case, i.e. when we disable, will it disable the text model but not the vision model?

Contributor Author

I don't think so; this is just because we used to remove `_require_grads_hook`. Now we always iterate over the full list `_require_grads_hooks` (with an "s"), so every registered hook (vision, text, or whatever) should be removed.

Member

Ahh my bad, didn't see the "s" at the end.

Contributor Author

Might be a bad naming then haha


def disable_input_require_grads(self):
"""
Removes the `_require_grads_hook`.
"""
self._require_grads_hook.remove()
hooks = getattr(self, "_require_grads_hooks", None)
if not hooks:
return

for hook in hooks:
hook.remove()

self._require_grads_hooks = []
if hasattr(self, "_require_grads_hook"):
del self._require_grads_hook
Member

Out of curiosity, is it required to explicitly delete?

Contributor Author

Just out of safety. I'm not certain it's always necessary, but since we don't know what people were doing with that hook in their FT scripts, it's safer to remove it so no reference remains.


def get_encoder(self, modality: Optional[str] = None):
"""
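A minimal usage sketch of the public methods this hunk touches (not from the diff itself): the checkpoint id is a placeholder, the training loop is omitted, and only existing `PreTrainedModel` APIs are used.

from transformers import AutoModel

# Placeholder checkpoint id for a multimodal model; substitute a real one.
model = AutoModel.from_pretrained("org/some-multimodal-checkpoint")

# Freeze the backbone, as in a typical LoRA/adapter setup.
for param in model.parameters():
    param.requires_grad_(False)

# With gradient checkpointing, activations are recomputed in backward; unless the
# embedding outputs require grad, nothing flows back through the frozen backbone.
# enable_input_require_grads() now registers one hook per embedding module
# (text, vision, audio, ...), not just on the first one it finds.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# ... attach adapters and train ...

# Removes every hook in `_require_grads_hooks` (and the BC alias `_require_grads_hook`).
model.disable_input_require_grads()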
6 changes: 4 additions & 2 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -888,10 +888,12 @@ def __init__(self, config: Blip2QFormerConfig):
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings
# The Q-Former operates on embeddings provided by upstream modules (e.g. query tokens or text embeddings).
# It does not own input embeddings itself, so we return `None` to signal that there is nothing to update.
return None

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
raise NotImplementedError("Blip2QFormerModel does not own input embeddings and cannot set them.")

def get_extended_attention_mask(
self,
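To illustrate the contract this change relies on, here is a small sketch with a hypothetical module (not the real `Blip2QFormerModel`): any code that walks submodules, like the generic `enable_input_require_grads` above, must treat a `None` return from `get_input_embeddings()` as "nothing to hook".

import torch.nn as nn

class EmbeddinglessModule(nn.Module):
    """Hypothetical module that, like the Q-Former, consumes embeddings built upstream."""

    def get_input_embeddings(self):
        return None  # nothing to resize, tie, or hook

    def set_input_embeddings(self, value):
        raise NotImplementedError("this module does not own input embeddings")

def make_inputs_require_grads(module, input, output):
    output.requires_grad_(True)

module = EmbeddinglessModule()
embeddings = module.get_input_embeddings()
if embeddings is not None:  # same guard as the loop in modeling_utils.py above
    embeddings.register_forward_hook(make_inputs_require_grads)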
48 changes: 0 additions & 48 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -787,36 +787,6 @@ def __init__(self, config: Idefics2Config):

self.post_init()

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings.

This is useful for lora when using gradient checkpointing.
c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
Comment on lines -790 to -795
Member


nice 🔪


Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
"""

def get_lowest_module(module):
if len(list(module.children())) == 0:
# If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
return module
else:
# Recursively call the function on each child module
return get_lowest_module(list(module.children())[0])

def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
make_inputs_require_grads
)

def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()

def get_input_embeddings(self):
return self.text_model.get_input_embeddings()

@@ -1009,24 +979,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
the model weights fixed.
"""

def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
make_inputs_require_grads
)

def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()

def get_input_embeddings(self):
return self.model.text_model.get_input_embeddings()

52 changes: 0 additions & 52 deletions src/transformers/models/idefics3/modeling_idefics3.py
@@ -530,38 +530,6 @@ def __init__(self, config: Idefics3Config):

self.post_init()

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings.

This is useful for lora when using gradient checkpointing.
c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032

Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
"""

def get_lowest_module(module):
if len(list(module.children())) == 0:
# If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
return module
else:
# Recursively call the function on each child module
return get_lowest_module(list(module.children())[0])

def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
make_inputs_require_grads
)

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
def get_input_embeddings(self):
return self.text_model.get_input_embeddings()
@@ -765,26 +733,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
the model weights fixed.
"""

def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
make_inputs_require_grads
)

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()

# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.model.text_model.get_input_embeddings()
@@ -1076,6 +1076,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
config: Qwen2_5OmniVisionEncoderConfig
_no_split_modules = ["Qwen2_5OmniVisionBlock"]
_input_embed_layer = "patch_embed"
input_modalities = ("image", "video")

def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
@@ -1920,6 +1920,7 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel):
config: Qwen2_5OmniVisionEncoderConfig
input_modalities = ("image", "video")
_no_split_modules = ["Qwen2_5OmniVisionBlock"]
_input_embed_layer = "patch_embed"

def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -307,6 +307,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
config: Qwen2_5_VLVisionConfig
_no_split_modules = ["Qwen2_5_VLVisionBlock"]
_input_embed_layer = "patch_embed"

def __init__(self, config, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -179,6 +179,7 @@ class Qwen2_5_VLPreTrainedModel(Qwen2VLPreTrainedModel):
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
config: Qwen2_5_VLVisionConfig
_no_split_modules = ["Qwen2_5_VLVisionBlock"]
_input_embed_layer = "patch_embed"

def __init__(self, config, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -671,6 +671,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
config: Qwen2VLVisionConfig
input_modalities = ("image", "video")
_no_split_modules = ["Qwen2VLVisionBlock"]
_input_embed_layer = "patch_embed"

def __init__(self, config) -> None:
super().__init__(config)
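The `_input_embed_layer = "patch_embed"` lines added above name the attribute that plays the role of the input embedding for these vision transformers. The attribute is consumed elsewhere in transformers (the resolution logic is not part of this diff); the following sketch uses a hypothetical class and only illustrates the idea of an attribute-driven lookup.

import torch.nn as nn

class VisionBackboneSketch(nn.Module):
    """Hypothetical stand-in for the Qwen vision transformers above."""

    # Names the module attribute that acts as the input embedding (the patch embedding here).
    _input_embed_layer = "patch_embed"

    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, 64, kernel_size=14, stride=14)

    def get_input_embeddings(self) -> nn.Module:
        # Attribute-driven lookup; the real resolution lives in PreTrainedModel.
        return getattr(self, self._input_embed_layer)

backbone = VisionBackboneSketch()
assert isinstance(backbone.get_input_embeddings(), nn.Conv2d)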
@@ -686,10 +686,10 @@ def _freeze_parameters(self):
self._requires_grad = False

def get_input_embeddings(self) -> nn.Module:
return self.conv1
return self.conv2d1

def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def set_input_embeddings(self, value):
self.conv2d1 = value

def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
@@ -1194,6 +1194,12 @@ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
self.n_window_infer = self.config.n_window_infer
self.conv_chunksize = self.config.conv_chunksize

def get_input_embeddings(self):
return self.conv2d1

def set_input_embeddings(self, value):
self.conv2d1 = value

def forward(
self,
input_features,
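Exposing `conv2d1` here (and `patch_embed` above) as an input embedding makes these layers visible to the requires-grad forward hook. A short self-contained sketch of that hook on a frozen convolution (arbitrary layer sizes, assumed only for illustration):

import torch
import torch.nn as nn

def make_inputs_require_grads(module, input, output):
    output.requires_grad_(True)

# A frozen convolutional "embedding", as in adapter fine-tuning of a multimodal backbone.
conv = nn.Conv2d(1, 8, kernel_size=3)
for param in conv.parameters():
    param.requires_grad_(False)

x = torch.randn(1, 1, 16, 16)
print(conv(x).requires_grad)  # False: nothing upstream requires grad

conv.register_forward_hook(make_inputs_require_grads)
print(conv(x).requires_grad)  # True: the hook flips the flag on the embedding output,
                              # which is what gradient checkpointing needs downstream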
48 changes: 0 additions & 48 deletions src/transformers/models/smolvlm/modeling_smolvlm.py
@@ -470,36 +470,6 @@ def __init__(self, config: SmolVLMConfig):

self.post_init()

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings.

This is useful for lora when using gradient checkpointing.
c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032

Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
"""

def get_lowest_module(module):
if len(list(module.children())) == 0:
# If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
return module
else:
# Recursively call the function on each child module
return get_lowest_module(list(module.children())[0])

def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
make_inputs_require_grads
)

def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()

def get_input_embeddings(self):
return self.text_model.get_input_embeddings()

@@ -748,24 +718,6 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

def enable_input_require_grads(self):
"""
Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
the model weights fixed.
"""

def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
make_inputs_require_grads
)

def disable_input_require_grads(self):
self._text_require_grads_hook.remove()
self._vision_require_grads_hook.remove()

def get_input_embeddings(self):
return self.model.text_model.get_input_embeddings()

7 changes: 7 additions & 0 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
@@ -150,6 +150,13 @@ def __init__(self, config: TimmWrapperConfig):
self.timm_model = _create_timm_model_with_error_handling(config, num_classes=0, **extra_init_kwargs)
self.post_init()

def get_input_embeddings(self):
# Vision backbones from timm do not expose token embeddings, so there is nothing to return.
return None

def set_input_embeddings(self, value):
raise NotImplementedError("TimmWrapperModel does not own token embeddings and cannot set them.")

@auto_docstring
def forward(
self,
19 changes: 15 additions & 4 deletions tests/models/bart/test_modeling_bart.py
@@ -478,23 +478,34 @@ def test_inputs_embeds(self):
with torch.no_grad():
model(**inputs)[0]

@unittest.skip("Bart no longer always uses self.shared so not working.")
def test_input_embeddings_support_forward_hook(self):
# Make sure that registering hooks on the input embeddings are indeed called
# in forward. This is necessary for gradient checkpointing in PEFT, see also #41821.
# For BART with tied embeddings, encoder and decoder have separate embedding modules,
# so we need to check that hooks on those modules are called during forward.
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()

hook = unittest.mock.MagicMock(return_value=None)
model.get_input_embeddings().register_forward_hook(hook)
hooks = []
base_model = model.model if hasattr(model, "model") else model

if hasattr(base_model, "encoder") and hasattr(base_model.encoder, "embed_tokens"):
hook = unittest.mock.MagicMock(return_value=None)
base_model.encoder.embed_tokens.register_forward_hook(hook)
hooks.append(hook)
if hasattr(base_model, "decoder") and hasattr(base_model.decoder, "embed_tokens"):
hook = unittest.mock.MagicMock(return_value=None)
base_model.decoder.embed_tokens.register_forward_hook(hook)
hooks.append(hook)

inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
model(**inputs)

self.assertGreater(hook.call_count, 0)
total_calls = sum(hook.call_count for hook in hooks)
self.assertGreater(total_calls, 0, f"Hooks on embeddings were not called for {model_class.__name__}")

@require_torch_fp16
def test_generate_fp16(self):
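As a general note on the testing pattern used above (a sketch, not BART-specific): a `MagicMock` with `return_value=None` is a valid forward hook, so its `call_count` tells you whether the hooked embedding module actually ran during forward.

import unittest.mock

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)
hook = unittest.mock.MagicMock(return_value=None)  # callable and returns None, so the output is untouched
embedding.register_forward_hook(hook)

embedding(torch.tensor([1, 2, 3]))
assert hook.call_count == 1  # the hook fired exactly once during forward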