diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 13ebc98ef56d..4d221e3b5469 100755
--- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -2006,6 +2006,19 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()
 
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """Overridden to handle the wrong key saved in some official checkpoints."""
+        if self.config.tie_word_embeddings:
+            # In some checkpoints, e.g. "facebook/hf-seamless-m4t-medium", the embedding weight is saved
+            # under model.decoder.embed_tokens, so we need to check for it here
+            if missing_keys is not None:
+                if "lm_head.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
+                    self.lm_head.weight = self.model.decoder.embed_tokens.weight
+                    missing_keys.discard("lm_head.weight")
+
+        # needs to be done afterwards, otherwise it raises an error because the correct weights are not yet present
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     def get_encoder(self):
         return self.model.encoder
 
@@ -3614,6 +3627,19 @@ def __init__(self, config, current_modality="text"):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """Overridden to handle the wrong key saved in some official checkpoints."""
+        if self.config.tie_word_embeddings:
+            # In some checkpoints, e.g. "facebook/hf-seamless-m4t-medium", the t2u embedding weight is only
+            # saved under t2u_model.lm_head, so we need to check for it here
+            if missing_keys is not None:
+                if "t2u_model.model.decoder.embed_tokens.weight" in missing_keys and "t2u_model.lm_head.weight" not in missing_keys:
+                    self.t2u_model.model.decoder.embed_tokens.weight = self.t2u_model.lm_head.weight
+                    missing_keys.discard("t2u_model.model.decoder.embed_tokens.weight")
+
+        # needs to be done afterwards, otherwise it raises an error because the correct weights are not yet present
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     def set_modality(self, modality="text"):
         if modality == "text":
             self.main_input_name = "input_ids"
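
A minimal sketch of how the fix can be checked locally (illustrative only, not part of the diff; it assumes `config.tie_word_embeddings` is `True` for the affected checkpoint): after loading, the remapped key should no longer be reported missing, and the tied tensors should share storage.

```python
# Illustrative check, not part of the diff: load an affected checkpoint and
# verify that the t2u decoder embedding and the t2u lm_head share storage.
from transformers import SeamlessM4TModel

model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")

embed = model.t2u_model.model.decoder.embed_tokens.weight
lm_head = model.t2u_model.lm_head.weight
# Same underlying tensor after tie_weights() has run (assumes tie_word_embeddings=True).
assert embed.data_ptr() == lm_head.data_ptr()
```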