-
Notifications
You must be signed in to change notification settings - Fork 31.3k
Attempt to fix VLM gradient enabling #41993
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
61679f8
83c4c0e
c1f0329
3e160ed
d775535
b70f0e1
f8e7827
bd2f811
17c0db4
a779501
5b07e43
4db9403
b1a79e3
6c39af2
1cb35ff
f57834b
8b47a00
1552922
d69cdb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1976,13 +1976,44 @@ def enable_input_require_grads(self): | |
| def make_inputs_require_grads(module, input, output): | ||
| output.requires_grad_(True) | ||
|
|
||
| self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) | ||
| hooks = [] | ||
| seen_modules = set() | ||
|
|
||
| for module in self.modules(): | ||
| if not (isinstance(module, PreTrainedModel) and hasattr(module, "get_input_embeddings")): | ||
| continue | ||
|
|
||
| input_embeddings = module.get_input_embeddings() | ||
|
|
||
| if input_embeddings is None: | ||
| continue | ||
|
|
||
| embedding_id = id(input_embeddings) | ||
| if embedding_id in seen_modules: | ||
| continue | ||
|
|
||
| seen_modules.add(embedding_id) | ||
| hooks.append(input_embeddings.register_forward_hook(make_inputs_require_grads)) | ||
|
|
||
| self._require_grads_hooks = hooks | ||
| if hooks: | ||
| # for BC | ||
| self._require_grads_hook = hooks[0] | ||
|
Comment on lines
+1999
to
+2001
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. aren't we ignoring all hooks except for the first one in this case, i.e. when we disable it will disable the text model and will not disable vision model?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so, this is just because we used to remove
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ahh my bad, didn't see the "s" at the end
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. might be a bad naming then haha |
||
|
|
||
| def disable_input_require_grads(self): | ||
| """ | ||
| Removes the `_require_grads_hook`. | ||
| """ | ||
| self._require_grads_hook.remove() | ||
| hooks = getattr(self, "_require_grads_hooks", None) | ||
| if not hooks: | ||
| return | ||
|
|
||
| for hook in hooks: | ||
| hook.remove() | ||
|
|
||
| self._require_grads_hooks = [] | ||
| if hasattr(self, "_require_grads_hook"): | ||
| del self._require_grads_hook | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. out of curiosity, is it required to explicitly delete?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just out of safety, not certain it's always necessary but not knowing what people were doing with that hook in their FT scripts I think it's safer to remove it so no reference remains |
||
|
|
||
| def get_encoder(self, modality: Optional[str] = None): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -787,36 +787,6 @@ def __init__(self, config: Idefics2Config): | |
|
|
||
| self.post_init() | ||
|
|
||
| def enable_input_require_grads(self): | ||
| """ | ||
| Enables the gradients for the input embeddings. | ||
|
|
||
| This is useful for lora when using gradient checkpointing. | ||
| c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032 | ||
|
Comment on lines
-790
to
-795
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice 🔪 |
||
|
|
||
| Override to set output.requires_grad = True for both the decoder's and vision model's embeddings. | ||
| """ | ||
|
|
||
| def get_lowest_module(module): | ||
| if len(list(module.children())) == 0: | ||
| # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.) | ||
| return module | ||
| else: | ||
| # Recursively call the function on each child module | ||
| return get_lowest_module(list(module.children())[0]) | ||
|
|
||
| def make_inputs_require_grads(module, input, output): | ||
| output.requires_grad_(True) | ||
|
|
||
| self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) | ||
| self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook( | ||
| make_inputs_require_grads | ||
| ) | ||
|
|
||
| def disable_input_require_grads(self): | ||
| self._text_require_grads_hook.remove() | ||
| self._vision_require_grads_hook.remove() | ||
|
|
||
| def get_input_embeddings(self): | ||
| return self.text_model.get_input_embeddings() | ||
|
|
||
|
|
@@ -1009,24 +979,6 @@ def __init__(self, config): | |
| # Initialize weights and apply final processing | ||
| self.post_init() | ||
|
|
||
| def enable_input_require_grads(self): | ||
| """ | ||
| Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping | ||
| the model weights fixed. | ||
| """ | ||
|
|
||
| def make_inputs_require_grads(module, input, output): | ||
| output.requires_grad_(True) | ||
|
|
||
| self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) | ||
| self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook( | ||
| make_inputs_require_grads | ||
| ) | ||
|
|
||
| def disable_input_require_grads(self): | ||
| self._text_require_grads_hook.remove() | ||
| self._vision_require_grads_hook.remove() | ||
|
|
||
| def get_input_embeddings(self): | ||
| return self.model.text_model.get_input_embeddings() | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super clean!