26 changes: 26 additions & 0 deletions src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -2006,6 +2006,19 @@ def __init__(
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
"""We need to overload here to handle the wrong key saved in some main checkpoints."""
if self.config.tie_word_embeddings:
# Some checkpoints, e.g. "facebook/hf-seamless-m4t-medium", only save the embedding weight under
# model.decoder.embed_tokens, so we need to check for that here
if missing_keys is not None:
if "lm_head.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
self.lm_head.weight = self.model.decoder.embed_tokens.weight
missing_keys.discard("lm_head.weight")

# this must run afterwards, otherwise it raises an error because the correct weights are not yet present
super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
Comment on lines +2009 to +2020
Collaborator

yep this makes sense, though I think the explicit mapping can basically work in reverse. See #42362.
It affects more models!
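For readers less familiar with the new `tie_weights(missing_keys=...)` hook, here is a minimal, self-contained sketch of the pattern this override relies on: when a checkpoint stored only one of the two tied tensors, alias the missing parameter to the loaded one and drop it from `missing_keys` before the generic tying logic runs. The names below (`TinyLM`, `load_with_tying`) are made up for illustration and are not transformers APIs.

```python
# Illustrative sketch only, not part of this diff.
import torch
from torch import nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 10, dim: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


def load_with_tying(model: TinyLM, state_dict: dict, missing_keys: set) -> TinyLM:
    # Load what the checkpoint actually contains; strict=False tolerates the missing head.
    model.load_state_dict(state_dict, strict=False)
    if "lm_head.weight" in missing_keys and "embed_tokens.weight" not in missing_keys:
        # The checkpoint only saved the embedding: alias the missing parameter to it
        # and silence the spurious missing-key report, mirroring the override above.
        model.lm_head.weight = model.embed_tokens.weight
        missing_keys.discard("lm_head.weight")
    return model


ckpt = {"embed_tokens.weight": torch.randn(10, 4)}
missing = {"lm_head.weight"}
model = load_with_tying(TinyLM(), ckpt, missing)
assert model.lm_head.weight is model.embed_tokens.weight and not missing
```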


def get_encoder(self):
return self.model.encoder

@@ -3614,6 +3627,19 @@ def __init__(self, config, current_modality="text"):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
"""We need to overload here to handle the wrong key saved in some main checkpoints."""
if self.config.tie_word_embeddings:
# Some checkpoints, e.g. "facebook/hf-seamless-m4t-medium", only save the t2u embedding weight under
# lm_head, so we need to check for that here
if missing_keys is not None:
if "t2u_model.model.decoder.embed_tokens.weight" in missing_keys and "t2u_model.lm_head.weight" not in missing_keys:
self.t2u_model.model.decoder.embed_tokens.weight = self.t2u_model.lm_head.weight
missing_keys.discard("t2u_model.model.decoder.embed_tokens.weight")

# this must run afterwards, otherwise it raises an error because the correct weights are not yet present
super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)

def set_modality(self, modality="text"):
if modality == "text":
self.main_input_name = "input_ids"