4 changes: 4 additions & 0 deletions src/transformers/configuration_utils.py
@@ -751,6 +751,10 @@ def from_dict(

config = cls(**config_dict)

# default tie_word_embeddings to False if None, see https://github.com/huggingface/transformers/issues/42313
if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is None:
config.tie_word_embeddings = False

# Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs:
num_labels = kwargs["num_labels"]
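To illustrate the intended behavior of the from_dict change above, a minimal sketch (LlamaConfig is used here only as an example of a config class exposing tie_word_embeddings; it is not part of this PR):

```python
from transformers import LlamaConfig

# A serialized config that leaves tie_word_embeddings unset (None) used to be ambiguous;
# with this change, from_dict() resolves it to False instead of leaving it as None.
config = LlamaConfig.from_dict({"tie_word_embeddings": None})
print(config.tie_word_embeddings)  # expected: False with this PR applied
```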
1 change: 1 addition & 0 deletions src/transformers/conversion_mapping.py
@@ -116,6 +116,7 @@ def _build_checkpoint_conversion_mapping():
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["olmoe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()

return mapping
1 change: 0 additions & 1 deletion src/transformers/core_model_loading.py
@@ -504,7 +504,6 @@ def set_param_for_module(
missing_keys.discard(target_name)
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
mismatch_keys.add((target_name, param_value.shape, ref.shape))
module_obj.param_name._is_hf_initialized = False # Needs to be initialized
Comment on lines 506 to -507
Member:
Yes, it was completely masked by log_to_misc, along with other issues we currently have on some models. We will need to update the logger to avoid silencing important issues!

else:
# super important otherwise _init_weight will re-init the param
param_value._is_hf_initialized = True
6 changes: 6 additions & 0 deletions src/transformers/integrations/hub_kernels.py
@@ -290,6 +290,12 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
mapping[kernel_name] = kernel
except FileNotFoundError:
mapping[kernel_name] = None
except AssertionError as error:
logger.warning_once(
f"Failed to load the '{kernel_name}' kernel from '{repo_id}' because the current environment does not "
f"support the required backend: {error}"
)
mapping[kernel_name] = None

else:
# Try to import is_{kernel_name}_available from ..utils
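The addition above follows the usual "treat an unsupported backend as an unavailable kernel" pattern: an AssertionError raised while importing the kernel is downgraded to a warning, and the kernel is recorded as None so callers fall back to the default implementation. A self-contained sketch of that pattern (a hypothetical loader function, not the hub_kernels API):

```python
import logging
from types import ModuleType
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def try_load_kernel(kernel_name: str, loader: Callable[[str], ModuleType]) -> Optional[ModuleType]:
    """Return the loaded kernel module, or None if it is missing or unsupported."""
    try:
        return loader(kernel_name)
    except FileNotFoundError:
        # The kernel artifact simply is not there.
        return None
    except AssertionError as error:
        # Some kernels assert on the backend (e.g. a specific GPU architecture) at
        # import time; downgrade that to a warning and fall back to the default path.
        logger.warning(
            "Failed to load the '%s' kernel because the current environment does not "
            "support the required backend: %s",
            kernel_name,
            error,
        )
        return None
```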
21 changes: 18 additions & 3 deletions src/transformers/modeling_utils.py
@@ -2254,8 +2254,17 @@ def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
return expanded_tied_weights

tied_mapping = self._tied_weights_keys
text_config = self.config.get_text_config(decoder=True)
if not hasattr(text_config, "tie_word_embeddings"):
logger.warning(
f"Text config {text_config.__class__.__name__} does not have 'tie_word_embeddings' attribute. "
"This may cause issues with weight tying."
)
tie_word_embeddings = getattr(text_config, "tie_word_embeddings", None)
Comment on lines +2257 to +2263
Member:
I don't think all multimodals rely on the text_config here unfortunately, do they? That is, what we want is for each model to decide on its own whether it should tie its own weights. So we don't really want to check the text_config...

Contributor Author:
hmm, it's the opposite of what we discussed above. In that case, what do you suggest?

Member:
Well, the issue is that if you have a submodel such as self.text_model = AutoModel._from_config(text_config), you cannot know its weights in advance, as it can be any model. So that model is the only one responsible for its own weights.
We cannot delegate this to the top-most model, as that makes it way too hard and not scalable in general.
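To make the delegation idea concrete, a schematic sketch with hypothetical classes (this is not how transformers implements it; it only illustrates "each submodel ties its own weights from its own config"):

```python
from types import SimpleNamespace

import torch.nn as nn


class TextSubModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(32, 8)
        self.lm_head = nn.Linear(8, 32, bias=False)

    def tie_weights(self):
        # The submodel alone decides, from its own config, whether to tie.
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head.weight = self.embed_tokens.weight


class CompositeModel(nn.Module):
    def __init__(self, text_config):
        super().__init__()
        # The concrete text model could be anything (the AutoModel._from_config case),
        # so the parent cannot hardcode the submodel's tied keys at class level.
        self.text_model = TextSubModel(text_config)

    def tie_weights(self):
        # Delegate: each child handles its own tying.
        for module in self.children():
            if hasattr(module, "tie_weights"):
                module.tie_weights()


model = CompositeModel(SimpleNamespace(tie_word_embeddings=True))
model.tie_weights()
assert model.text_model.lm_head.weight is model.text_model.embed_tokens.weight
```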

tie_encoder_decoder = getattr(self.config, "tie_encoder_decoder", False)
should_tie = tie_encoder_decoder if tie_word_embeddings is None else tie_word_embeddings
Member:
So we basically disregard tie_encoder_decoder completely? Because with your change in configuration_utils, tie_word_embeddings will NEVER be None anymore, if I understand correctly.

Contributor Author:
Yeah, I'm reverting this unfortunately; I would indeed like to get rid of tie_encoder_decoder.

# If the config does not specify any tying, return empty dict
if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
if not should_tie:
return {}
# If None, return empty dict
elif tied_mapping is None:
@@ -3178,7 +3187,11 @@ def save_pretrained(
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

# Recursively descend to find tied weight keys
_tied_weights_keys = set(_get_tied_weight_keys(self))
tied_keys_attr = getattr(self, "all_tied_weights_keys", None)
if tied_keys_attr is not None:
_tied_weights_keys = set(tied_keys_attr.keys())
else:
_tied_weights_keys = set(_get_tied_weight_keys(self))
Comment on lines +3190 to +3194
Contributor Author:
This is the main change. IMO it's why tie_word_embeddings is not authoritative right now: it is entirely skipped if _tied_weights_keys is filled.

Member:
Yeah, I think we fill it in in many cases just to pass some tests, even if the released models don't tie weights.

Member:
Also, I am a bit lost here. Do we now try to tie weights with the _tied_weights_keys attribute even when config.tie_word_embeddings = False?

Contributor Author:
Yes, it's what I observed. I tried to get a reproducer here, @zucchini-nlp @ArthurZucker:

from transformers import PaliGemmaConfig, PaliGemmaForConditionalGeneration
from transformers.models.gemma.configuration_gemma import GemmaConfig
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig

vision_config = SiglipVisionConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=2,
    image_size=32,
    patch_size=4,
    vocab_size=64,
    projection_dim=32,
).to_dict()

text_config = GemmaConfig(
    vocab_size=64,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=4,
    tie_word_embeddings=False,
).to_dict()

config = PaliGemmaConfig(vision_config=vision_config, text_config=text_config)

print("Config tie flags:")
print(f"  config.tie_word_embeddings              -> {config.tie_word_embeddings}")
print(f"  config.text_config.tie_word_embeddings  -> {config.text_config.tie_word_embeddings}")

model = PaliGemmaForConditionalGeneration(config)
lm_head_weight = model.lm_head.weight
input_embed_weight = model.model.language_model.embed_tokens.weight
tied = lm_head_weight.data_ptr() == input_embed_weight.data_ptr()

print("\nModel tying details:")
print(f"  class _tied_weights_keys      -> {model._tied_weights_keys}")
print(f"  model.all_tied_weights_keys   -> {model.all_tied_weights_keys}")
print(f"  lm_head shares embedding tensor? -> {tied}")

This will show tied weights. However, if you go to modeling_paligemma and change

    _tied_weights_keys = {}  # {"lm_head.weight": "model.language_model.embed_tokens.weight"}

then it will work and the weights will not be tied. I think the solution is to use all_tied_weights_keys instead of _get_tied_weight_keys here; it's what I understand from the logic flow.

For instance, this breaks model generation for mistral3, but for several other models as well; this is why I wanted the run-slow upgrade I mentioned.

There might be a simpler fix; I'm not seeing it right now.

Member:
The fact here is that I think some older models were tying weights using _tie_weights (the one with the underscore), which was ALWAYS run, independently of the configs. So they could have tied weights even though the config technically says not to tie. In save_pretrained, we check the pointers of the weight tensors, so we know the tied weights from there and we cannot be wrong about it. But then, if we use the good-citizen all_tied_weights_keys, which correctly looks at the config, some of those tied weights will not find their source, as they are not SUPPOSED to be tied (but they were anyway).

So technically, looking at all potential _tied_weights_keys here is not wrong, as we check pointers anyway, and this avoids this kind of issue.
But indeed, in the future we want to always know which weights are tied and simply rely on the internal list (or, even better, recompute it with get_expanded_tied_weights_keys(all_submodels=True)).
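As a small illustration of the pointer check mentioned above (a minimal standalone sketch, not the save_pretrained code itself):

```python
from collections import defaultdict

import torch.nn as nn

model = nn.Module()
model.embed_tokens = nn.Embedding(10, 4)
model.lm_head = nn.Linear(4, 10, bias=False)
model.lm_head.weight = model.embed_tokens.weight  # tied, regardless of any config flag

# Group parameter names by underlying storage pointer: any group with more than one
# name is a set of tied weights, whatever the config or the class attribute claims.
ptrs = defaultdict(list)
for name, tensor in model.state_dict(keep_vars=True).items():
    ptrs[tensor.data_ptr()].append(name)

shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
print(shared_ptrs)  # {<ptr>: ['embed_tokens.weight', 'lm_head.weight']}
```

This is why looking at all potential _tied_weights_keys in save_pretrained is safe: whatever the keys claim, only tensors that actually share storage end up in shared_ptrs.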

Contributor Author:
> So technically, looking at all potential _tied_weights_keys here is not wrong, as we check pointers anyway, and this avoids this kind of issue.

Hmm, OK, but then why are models like mistral3 broken by the current setup? Ah, because you mean we did it in save_pretrained, but load_pretrained has to deal with what exists now?

Not super sure of the way forward there, can you elaborate a bit?

error_names = []
to_delete_names = set()
for names in shared_ptrs.values():
@@ -4410,7 +4423,9 @@ def _move_missing_keys_from_meta_to_cpu(
# The tied weight keys are in the "missing" usually, but they should not be moved (they will be tied anyway)
# This is especially important because if they are moved, they will lose the `_is_hf_initialized` flag, and they
# will be re-initialized for nothing (which can be quite long)
for key in missing_keys - self.all_tied_weights_keys.keys():
tied_keys_attr = getattr(self, "all_tied_weights_keys", {}) or {}
tied_keys = set(tied_keys_attr.keys())
for key in missing_keys - tied_keys:
Comment on lines -4413 to +4428
Member:
Why do we have this? all_tied_weights_keys is guaranteed to be a dict already at this point from post_init

Contributor Author:
Ah yes, indeed, it could be just:

tied_keys = set(self.all_tied_weights_keys.keys())
for key in missing_keys - tied_keys:
    ...  # do stuff

will update

Member:
No need to cast it again as a set; the keys view already supports set operations!
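For reference on that last point, dict views are set-like, so the set difference in the diff works without an explicit cast (plain Python, nothing transformers-specific):

```python
missing_keys = {"lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"}
all_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

# dict.keys() returns a view that implements the set interface, so '-' works directly.
print(missing_keys - all_tied_weights_keys.keys())
# {'model.embed_tokens.weight', 'model.norm.weight'} (order may vary)
```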

param = model_state_dict[key]
# Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
if param.device == torch.device("meta"):
1 change: 1 addition & 0 deletions src/transformers/models/fsmt/configuration_fsmt.py
@@ -194,6 +194,7 @@ def __init__(
bos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
num_hidden_layers=encoder_layers,
tie_word_embeddings=tie_word_embeddings,
)
if "decoder" in common_kwargs:
del common_kwargs["decoder"]
@@ -1069,7 +1069,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(

@auto_docstring
class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
_keep_in_fp32_modules_strict = ["codec_model"]
6 changes: 6 additions & 0 deletions src/transformers/models/musicgen/configuration_musicgen.py
@@ -221,5 +221,11 @@ def __init__(self, text_encoder, audio_encoder, decoder, **kwargs):
def sampling_rate(self):
return self.audio_encoder.sampling_rate

# Override: default to the decoder config here, since the base get_text_config crashes for this composite config (not 100% sure of that one)
def get_text_config(self, decoder=None, encoder=None):
if decoder is None and encoder is None:
decoder = True
return super().get_text_config(decoder=decoder, encoder=encoder)


__all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]
@@ -234,5 +234,11 @@ def __init__(
def sampling_rate(self):
return self.audio_encoder.sampling_rate

# Override: default to the decoder config here, since the base get_text_config crashes for this composite config (not 100% sure of that one)
def get_text_config(self, decoder=None, encoder=None):
if decoder is None and encoder is None:
decoder = True
return super().get_text_config(decoder=decoder, encoder=encoder)


__all__ = ["MusicgenMelodyConfig", "MusicgenMelodyDecoderConfig"]
3 changes: 3 additions & 0 deletions tests/models/fsmt/test_modeling_fsmt.py
@@ -125,6 +125,7 @@ def get_config(self):
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
tie_word_embeddings=True,
)

def prepare_config_and_inputs_for_common(self):
@@ -254,6 +255,7 @@ def test_ensure_weights_are_shared(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()

config.tie_word_embeddings = True
config.decoder.tie_word_embeddings = True
model = FSMTForConditionalGeneration(config)

# FSMT shares three weights.
@@ -270,6 +272,7 @@ def test_ensure_weights_are_shared(self):
)

config.tie_word_embeddings = False
config.decoder.tie_word_embeddings = False
model = FSMTForConditionalGeneration(config)

# FSMT shares three weights.
43 changes: 43 additions & 0 deletions tests/test_modeling_common.py
@@ -2100,6 +2100,49 @@ def test_tied_weights_keys(self):
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)

def test_tie_word_embeddings_is_authoritative(self):
original_config, _ = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
tied_config = copy.deepcopy(original_config)
tied_config.get_text_config().tie_word_embeddings = True

untied_config = copy.deepcopy(original_config)
untied_config.get_text_config().tie_word_embeddings = False

model_tied = model_class(tied_config)
model_untied = model_class(untied_config)

if not hasattr(model_tied, "_tied_weights_keys") or not model_tied._tied_weights_keys:
continue
Comment on lines +2116 to +2117
Contributor Author (@molbap, Nov 28, 2025):
Note from huddle: we should recurse through all submodels here to get the tied keys (see the sketch at the end of this diff).


tied_keys = model_tied._tied_weights_keys
state_dict_tied = model_tied.state_dict()
state_dict_untied = model_untied.state_dict()

for target_key, source_key in tied_keys.items():
if target_key not in state_dict_tied or source_key not in state_dict_tied:
continue
if target_key not in state_dict_untied or source_key not in state_dict_untied:
continue

target_tied_ptr = id_tensor_storage(state_dict_tied[target_key])
source_tied_ptr = id_tensor_storage(state_dict_tied[source_key])
target_untied_ptr = id_tensor_storage(state_dict_untied[target_key])
source_untied_ptr = id_tensor_storage(state_dict_untied[source_key])

self.assertEqual(
target_tied_ptr,
source_tied_ptr,
f"{model_class}: With tie_word_embeddings=True, '{target_key}' should share storage with '{source_key}'",
)
self.assertNotEqual(
target_untied_ptr,
source_untied_ptr,
f"{model_class}: With tie_word_embeddings=False, '{target_key}' should NOT share storage with '{source_key}'. "
f"Config tie_word_embeddings must be authoritative over class-level _tied_weights_keys.",
)

def test_model_weights_reload_no_missing_tied_weights(self):
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
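Following up on the huddle note in the new test ("should recurse through all submodels here to get the tied keys"), a minimal sketch of what that recursion could look like (a hypothetical helper, assuming every submodule's _tied_weights_keys maps target names to source names relative to that submodule):

```python
import torch.nn as nn


def collect_tied_weights_keys(model: nn.Module) -> dict:
    """Walk all submodules and gather their _tied_weights_keys, re-prefixed to the root."""
    tied = {}
    for prefix, module in model.named_modules():
        mapping = getattr(module, "_tied_weights_keys", None) or {}
        for target, source in mapping.items():
            tied[f"{prefix}.{target}" if prefix else target] = (
                f"{prefix}.{source}" if prefix else source
            )
    return tied
```

In the test above, such a helper would replace the direct read of model_tied._tied_weights_keys, so tied keys declared by nested submodels are exercised as well.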