Fix weight tying logic between _tied_weights_keys and tie_word_embeddings
#42385
@@ -2254,8 +2254,17 @@ def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
             return expanded_tied_weights

+        tied_mapping = self._tied_weights_keys
+        text_config = self.config.get_text_config(decoder=True)
+        if not hasattr(text_config, "tie_word_embeddings"):
+            logger.warning(
+                f"Text config {text_config.__class__.__name__} does not have 'tie_word_embeddings' attribute. "
+                "This may cause issues with weight tying."
+            )
+        tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
Comment on lines +2257 to +2263

Member
I don't think all multimodals rely on the text_config here, unfortunately, do they? I.e. what we want is that each model decides on its own whether it should tie its own weights. So we don't really want to check the text_config...

Contributor (Author)
Hmm, it's the opposite of what we discussed above. In that case, what do you suggest?

Member
Well, the issue is that if you have a submodel such as
+        tie_encoder_decoder = getattr(self.config, "tie_encoder_decoder", False)
+        should_tie = tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings
Member
So we basically disregard

Contributor (Author)
Yeah, I'm reverting this, unfortunately. I would indeed like to get rid of
         # If the config does not specify any tying, return empty dict
-        if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
+        if not should_tie:
             return {}
         # If None, return empty dict
         elif tied_mapping is None:
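To make the resolution order in this hunk concrete, here is a minimal standalone sketch (the function name and the bare config argument are illustrative; the attribute lookups mirror the diff above):

def resolve_should_tie(config):
    # Prefer the text sub-config's tie_word_embeddings when it is set;
    # otherwise fall back to tie_encoder_decoder on the top-level config.
    text_config = config.get_text_config(decoder=True)
    tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
    tie_encoder_decoder = getattr(config, "tie_encoder_decoder", False)
    return tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings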
@@ -3178,7 +3187,11 @@ def save_pretrained(
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

-            # Recursively descend to find tied weight keys
-            _tied_weights_keys = set(_get_tied_weight_keys(self))
+            tied_keys_attr = getattr(self, "all_tied_weights_keys", None)
+            if tied_keys_attr is not None:
+                _tied_weights_keys = set(tied_keys_attr.keys())
+            else:
+                _tied_weights_keys = set(_get_tied_weight_keys(self))
Comment on lines +3190 to +3194

Contributor (Author)
This is the main change. IMO it's why
Member
Yeah, I think we fill it in in many cases just to pass some tests, even if the released models don't tie weights.

Member
Also, I am a bit lost here. Do we now try to tie weights with
Contributor (Author)
Yes, it's what I observed. Tried to get a reproducer here @zucchini-nlp @ArthurZucker

from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig

vision_config = SiglipVisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=2,
    image_size=32,
    patch_size=4,
    vocab_size=64,
    projection_dim=32,
).to_dict()

text_config = GemmaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=4,
    tie_word_embeddings=False,
).to_dict()

config = PaliGemmaConfig(vision_config=vision_config, text_config=text_config)

print("Config tie flags:")
print(f" config.tie_word_embeddings -> {config.tie_word_embeddings}")
print(f" config.text_config.tie_word_embeddings -> {config.text_config.tie_word_embeddings}")

model = PaliGemmaForConditionalGeneration(config)
lm_head_weight = model.lm_head.weight
input_embed_weight = model.model.language_model.embed_tokens.weight
tied = lm_head_weight.data_ptr() == input_embed_weight.data_ptr()

print("\nModel tying details:")
print(f" class _tied_weights_keys -> {model._tied_weights_keys}")
print(f" model.all_tied_weights_keys -> {model.all_tied_weights_keys}")
print(f" lm_head shares embedding tensor? -> {tied}")

This will show tied weights. However, if you go to modeling_paligemma and change

_tied_weights_keys = {}  # {"lm_head.weight": "model.language_model.embed_tokens.weight"}

then it will work and weights will not be tied. I think the solution is to use and for instance for There might be a simpler fix, not seeing it right now.
Member
The fact here is that I think some older models were tying using So technically looking at all potential

Contributor (Author)
Hmm ok, but then why are models like mistral3 broken by the current setup? Ah, because you mean we did it in Not super sure of the way forward there, can you elaborate a bit?
             error_names = []
             to_delete_names = set()
             for names in shared_ptrs.values():
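For context, the shared_ptrs grouping referenced above can be illustrated with a small standalone helper (a sketch only; the helper name is made up and this is not the library's exact implementation):

import collections

def find_shared_tensor_names(state_dict):
    # Group parameter names by the data pointer of their underlying storage;
    # any group with more than one name corresponds to tied/shared tensors.
    ptrs = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        ptrs[tensor.data_ptr()].append(name)
    return {ptr: names for ptr, names in ptrs.items() if len(names) > 1}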
|
@@ -4410,7 +4423,9 @@ def _move_missing_keys_from_meta_to_cpu(
         # The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
         # This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
         # will be re-initialized for nothing (which can be quite long)
-        for key in missing_keys - self.all_tied_weights_keys.keys():
+        tied_keys_attr = getattr(self, "all_tied_weights_keys", {}) or {}
+        tied_keys = set(tied_keys_attr.keys())
+        for key in missing_keys - tied_keys:
Comment on lines -4413 to +4428

Member
Why do we have this?
Contributor (Author)
Ah yes indeed, could be just:

tied_keys = set(self.all_tied_weights_keys.keys())
for key in missing_keys - tied_keys:
    # do stuff

Will update.

Member
No need to cast it again as a
             param = model_state_dict[key]
             # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
             if param.device == torch.device("meta"):
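As a quick illustration of the set difference used here (the key names below are invented for the example):

missing_keys = {"lm_head.weight", "model.layers.0.mlp.up_proj.weight"}
# Tied targets are skipped: they will share storage with their source after tying,
# so moving them off the meta device (and re-initializing them) would be wasted work.
tied_keys = {"lm_head.weight"}
print(missing_keys - tied_keys)  # {'model.layers.0.mlp.up_proj.weight'}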
@@ -2100,6 +2100,49 @@ def test_tied_weights_keys(self):
             f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
         )

+    def test_tie_word_embeddings_is_authoritative(self):
+        original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            tied_config = copy.deepcopy(original_config)
+            tied_config.get_text_config().tie_word_embeddings = True
+
+            untied_config = copy.deepcopy(original_config)
+            untied_config.get_text_config().tie_word_embeddings = False
+
+            model_tied = model_class(tied_config)
+            model_untied = model_class(untied_config)
+
+            if not hasattr(model_tied, "_tied_weights_keys") or not model_tied._tied_weights_keys:
+                continue
Comment on lines +2116 to +2117

Contributor (Author)
Note from huddle: should recurse through all submodels here to get the tied keys.
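A rough sketch of what that recursion could look like (the helper below is hypothetical and assumes _tied_weights_keys is a dict mapping target key to source key, as elsewhere in this PR):

import torch.nn as nn

def collect_tied_keys(module: nn.Module, prefix: str = "") -> dict:
    # Merge the class-level _tied_weights_keys declared on this module and on
    # every submodule, prefixing entries with the submodule path.
    declared = getattr(module, "_tied_weights_keys", None)
    mapping = {}
    if isinstance(declared, dict):
        mapping.update({f"{prefix}{k}": f"{prefix}{v}" for k, v in declared.items()})
    for name, child in module.named_children():
        mapping.update(collect_tied_keys(child, prefix=f"{prefix}{name}."))
    return mapping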
+            tied_keys = model_tied._tied_weights_keys
+            state_dict_tied = model_tied.state_dict()
+            state_dict_untied = model_untied.state_dict()
+
+            for target_key, source_key in tied_keys.items():
+                if target_key not in state_dict_tied or source_key not in state_dict_tied:
+                    continue
+                if target_key not in state_dict_untied or source_key not in state_dict_untied:
+                    continue
+
+                target_tied_ptr = id_tensor_storage(state_dict_tied[target_key])
+                source_tied_ptr = id_tensor_storage(state_dict_tied[source_key])
+                target_untied_ptr = id_tensor_storage(state_dict_untied[target_key])
+                source_untied_ptr = id_tensor_storage(state_dict_untied[source_key])
+
+                self.assertEqual(
+                    target_tied_ptr,
+                    source_tied_ptr,
+                    f"{model_class}: With tie_word_embeddings=True, '{target_key}' should share storage with '{source_key}'",
+                )
+                self.assertNotEqual(
+                    target_untied_ptr,
+                    source_untied_ptr,
+                    f"{model_class}: With tie_word_embeddings=False, '{target_key}' should NOT share storage with '{source_key}'. "
+                    f"Config tie_word_embeddings must be authoritative over class-level _tied_weights_keys.",
+                )
+
     def test_model_weights_reload_no_missing_tied_weights(self):
         for model_class in self.all_model_classes:
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
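For reference, the storage-identity check this test relies on can be shown in plain PyTorch (using data_ptr here for brevity instead of the id_tensor_storage helper used above):

import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie: both modules now share one Parameter

# Shared storage <=> the weights are tied
print(lm_head.weight.data_ptr() == embed.weight.data_ptr())  # True

untied_head = nn.Linear(4, 10, bias=False)
print(untied_head.weight.data_ptr() == embed.weight.data_ptr())  # False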
Yes, it was completely masked by the log_to_misc, along with other issues we currently have on some models. We will need to update the logger to avoid silencing important issues!