From 1e17ea4bd7ad32c5e9b540ec0f503dd1db8b6abf Mon Sep 17 00:00:00 2001
From: Eric B
Date: Tue, 25 Nov 2025 08:01:55 +0100
Subject: [PATCH 1/2] Remove unnecessary tied weights?

---
 src/transformers/models/seamless_m4t/modeling_seamless_m4t.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 13ebc98ef56d..885acd90f02c 100755
--- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -1981,7 +1981,6 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
         "text_encoder",
         "text_decoder",
     ]
-    _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}

     def __init__(
         self,

From ed8b0a60436c1627fc4e25c3346603320d96fe6d Mon Sep 17 00:00:00 2001
From: vasqu
Date: Tue, 25 Nov 2025 18:33:17 +0100
Subject: [PATCH 2/2] fix

---
 .../seamless_m4t/modeling_seamless_m4t.py | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
index 885acd90f02c..4d221e3b5469 100755
--- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -1981,6 +1981,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel,
         "text_encoder",
         "text_decoder",
     ]
+    _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"}

     def __init__(
         self,
@@ -2005,6 +2006,19 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()

+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """We need to override here to handle the wrong key saved in some main checkpoints."""
+        if self.config.tie_word_embeddings:
+            # Some checkpoints, e.g. "facebook/hf-seamless-m4t-medium", store the embedding weight under
+            # decoder.embed_tokens, so we need to check for that here
+            if missing_keys is not None:
+                if "lm_head.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
+                    self.lm_head.weight = self.model.decoder.embed_tokens.weight
+                    missing_keys.discard("lm_head.weight")
+
+        # needs to be done afterwards, otherwise it raises an error because the correct weights are not present yet
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     def get_encoder(self):
         return self.model.encoder

@@ -3613,6 +3627,19 @@ def __init__(self, config, current_modality="text"):
         # Initialize weights and apply final processing
         self.post_init()

+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """We need to override here to handle the wrong key saved in some main checkpoints."""
+        if self.config.tie_word_embeddings:
+            # Some checkpoints, e.g. "facebook/hf-seamless-m4t-medium", store the embedding weight under
+            # decoder.embed_tokens, so we need to check for that here
+            if missing_keys is not None:
+                if "t2u_model.model.decoder.embed_tokens.weight" in missing_keys and "t2u_model.lm_head.weight" not in missing_keys:
+                    self.t2u_model.model.decoder.embed_tokens.weight = self.t2u_model.lm_head.weight
+                    missing_keys.discard("t2u_model.model.decoder.embed_tokens.weight")
+
+        # needs to be done afterwards, otherwise it raises an error because the correct weights are not present yet
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     def set_modality(self, modality="text"):
         if modality == "text":
             self.main_input_name = "input_ids"
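
For context, a minimal sketch (not part of the patch, and assuming the checkpoint name and attribute paths referenced in the hunks above): with the overridden tie_weights applied, loading a checkpoint whose embedding weight is stored under decoder.embed_tokens should no longer leave lm_head.weight unresolved, and the two parameters should share the same storage.

    from transformers import SeamlessM4TModel

    # "facebook/hf-seamless-m4t-medium" is the checkpoint mentioned in the comments above;
    # with the overridden tie_weights, loading it should not report lm_head.weight as missing.
    model = SeamlessM4TModel.from_pretrained("facebook/hf-seamless-m4t-medium")

    t2u = model.t2u_model
    # The text-to-unit lm_head and decoder embeddings are expected to point at the same tensor.
    print(t2u.lm_head.weight.data_ptr() == t2u.model.decoder.embed_tokens.weight.data_ptr())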