
Commit 2a61590

Fix broken models due to modeling/weight mapping (#42468)
* lfm2
* fix
* config
* fixes
* fix prefixes
* finally
* copies
* oupsi
* rework key renaming
* fixes
* fix
* fix
* fix
* fix
* fix
* fix lfm2
* fix
* fix more bad mappings
1 parent a4a7008 commit 2a61590

File tree

14 files changed: +184 / -147 lines changed


src/transformers/conversion_mapping.py

Lines changed: 49 additions & 7 deletions
@@ -70,11 +70,56 @@ def _build_checkpoint_conversion_mapping():
                 operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
             ),
             WeightConverter(
-                source_patterns=["mlp.experts.*.down_proj.weight"],
+                source_patterns="mlp.experts.*.down_proj.weight",
                 target_patterns="mlp.experts.down_proj",
                 operations=[MergeModulelist(dim=0)],
             ),
         ],
+        "phimoe": [
+            WeightConverter(
+                source_patterns=[
+                    "mlp.experts.*.w1.weight",
+                    "mlp.experts.*.w3.weight",
+                ],
+                target_patterns="mlp.experts.gate_up_proj",
+                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
+            ),
+            WeightConverter(
+                source_patterns="mlp.experts.*.w2.weight",
+                target_patterns="mlp.experts.down_proj",
+                operations=[MergeModulelist(dim=0)],
+            ),
+        ],
+        "lfm2_moe": [
+            WeightConverter(
+                source_patterns=[
+                    "feed_forward.experts.*.w1.weight",
+                    "feed_forward.experts.*.w3.weight",
+                ],
+                target_patterns="feed_forward.experts.gate_up_proj",
+                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
+            ),
+            WeightConverter(
+                source_patterns="feed_forward.experts.*.w2.weight",
+                target_patterns="feed_forward.experts.down_proj",
+                operations=[MergeModulelist(dim=0)],
+            ),
+        ],
+        "jamba": [
+            WeightConverter(
+                source_patterns=[
+                    "feed_forward.experts.*.gate_proj.weight",
+                    "feed_forward.experts.*.up_proj.weight",
+                ],
+                target_patterns="feed_forward.experts.gate_up_proj",
+                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
+            ),
+            WeightConverter(
+                source_patterns="feed_forward.experts.*.down_proj.weight",
+                target_patterns="feed_forward.experts.down_proj",
+                operations=[MergeModulelist(dim=0)],
+            ),
+        ],
         "timm_wrapper": [
             # Simply add the prefix `timm_model`
             # TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming
@@ -117,16 +162,13 @@ def _build_checkpoint_conversion_mapping():
         ),
     ]
 
-    mapping["phimoe"] = mapping["mixtral"].copy()
     mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
     mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
-    mapping["dot1"] = mapping["qwen2_moe"].copy()
-    mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
+    mapping["dots1"] = mapping["qwen2_moe"].copy()
+    mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy()
     mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
     mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
-    mapping["jamba"] = mapping["qwen2_moe"].copy()
-    mapping["lfm2_moe"] = mapping["mixtral"].copy()
-    mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
+    mapping["longcat_flash"] = mapping["qwen2_moe"].copy()
     mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
     mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
     mapping["qwen3_next"] = mapping["qwen2_moe"].copy()

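Note: the new "phimoe", "lfm2_moe" and "jamba" entries above fuse per-expert weights into single stacked tensors at load time. As a rough, standalone sketch of what the MergeModulelist(dim=0) + Concatenate(dim=1) pipeline is expected to produce for separate w1/w3 expert weights (the helper below is hypothetical and not part of transformers):

import torch

def merge_experts_gate_up(w1_per_expert, w3_per_expert):
    # Stack the per-expert ModuleList weights into one tensor: (num_experts, intermediate, hidden)
    w1 = torch.stack(list(w1_per_expert), dim=0)
    w3 = torch.stack(list(w3_per_expert), dim=0)
    # Concatenate along dim=1 so every expert carries a fused [gate; up] projection
    return torch.cat([w1, w3], dim=1)  # (num_experts, 2 * intermediate, hidden)

# Tiny example: 4 experts, hidden size 8, intermediate size 16
w1s = [torch.randn(16, 8) for _ in range(4)]
w3s = [torch.randn(16, 8) for _ in range(4)]
print(merge_experts_gate_up(w1s, w3s).shape)  # torch.Size([4, 32, 8])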
src/transformers/core_model_loading.py

Lines changed: 69 additions & 99 deletions
@@ -48,27 +48,6 @@
 logger = logging.get_logger(__name__)
 
 
-def compile_glob_rule(source_glob: str, target_glob: str) -> tuple[re.Pattern, str]:
-    """
-    Convert a glob-style source + target into a full regex + replacement.
-
-    Rules:
-      - '*' in source_glob → (.*) capture group
-      - '*' in target_glob → \\1, \\2, ... backrefs
-    """
-    regex = re.compile(source_glob)
-
-    counter = 0
-
-    def _star_to_backref(_: re.Match) -> str:
-        nonlocal counter
-        counter += 1
-        return rf"\{counter}"
-
-    replacement = re.sub(r"\*", _star_to_backref, target_glob)
-    return regex, replacement
-
-
 def build_glob_alternation(
     globs: list[Union[WeightRenaming, WeightConverter, str]],
 ) -> tuple[re.Pattern, dict[str, str], dict[str, str]]:
@@ -303,6 +282,7 @@ def convert(
 class WeightTransform:
     source_patterns: Union[str, list[str]] = field(init=True)
     target_patterns: Union[str, list[str]] = field(init=True)
+    compiled_sources: re.Pattern = field(init=False)
 
     distributed_operation: Optional[TensorParallelLayer] = None
     quantization_operation: Optional[ConversionOps] = None
@@ -322,20 +302,27 @@ def __post_init__(self):
         for i, pattern in enumerate(self.target_patterns):
             # Some mapping contains `^` to notify start of string when matching -> remove it during reverse mapping
             pattern = pattern.removeprefix("^")
-            # This is ugly but needed for reverse mapping of Qwen2.5!
-            if r"(?!\.(language_model|visual))" in pattern:
-                pattern = pattern.replace(r"(?!\.(language_model|visual))", "")
-            # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
+            # Remove negative lookahead if any. This is ugly but needed for reverse mapping of Qwen2.5 and Sam3!
+            pattern = re.sub(r"\(\?!.+\)", "", pattern)
+            # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
             if r"(.+)" in pattern:
-                pattern = pattern.replace(r"(.+)", "")
+                pattern = pattern.replace(r"(.+)", r"\1")
             self.target_patterns[i] = pattern
 
-        # We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper)
+        # We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper, sam3)
         for i, pattern in enumerate(self.source_patterns):
             if r"\1" in pattern:
-                pattern = pattern.replace(r"\1", "")
+                pattern = pattern.replace(r"\1", r"(.+)")
             self.source_patterns[i] = pattern
 
+        # Construct the regex we will use to rename keys from the sources to the targets
+        branches = []
+        for i, source_pattern in enumerate(self.source_patterns):
+            group_name = f"g{i}"
+            pattern = source_pattern.replace(".*.", r"\..*\.")
+            branches.append(f"(?P<{group_name}>{pattern})")
+        self.compiled_sources = re.compile("|".join(branches))
+
     def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
         self.collected_tensors[source_pattern].append(future)
         self.layer_targets[target_key].add(source_key)
@@ -344,6 +331,32 @@ def reset(self) -> None:
         """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
         self.collected_tensors = defaultdict(list)
 
+    def rename_source_key(self, source_key: str) -> tuple[str, str | None]:
+        """
+        Return a tuple (renamed_key, source_pattern_producing_the_match).
+        Try renaming `source_key` according to the source and target patterns of the current WeightTransform.
+        In case of a one-to-many transform, i.e. we have several target patterns, the matching source pattern
+        will be replaced by the first of all the target patterns (they are then correctly expanded in the Operations).
+        """
+        # Try matching one of the alternation branches
+        match_object = self.compiled_sources.search(source_key)
+        if match_object is None:
+            return source_key, None
+        # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
+        matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None)
+        source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])]
+        # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
+        replacement = self.target_patterns[0]
+        # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
+        if r"\1" in replacement:
+            # The index of the internal group we need to replace is the index of the matched named group as it comes
+            # inside that matched named group
+            replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1
+            replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx))
+        renamed_key = source_key.replace(match_object.group(0), replacement)
+
+        return renamed_key, source_pattern_that_matched
+
     def reverse_transform(self) -> WeightTransform:
         """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations."""
         # TODO: check this and relax when quantizer have `reverse_op`
@@ -613,54 +626,30 @@ class SkipLayer(Exception):
     pass
 
 
-def repl(m, repl_map: dict[str, str]) -> str:
-    # Collect all groups that matched
-    matched_groups = [name for name, val in m.groupdict().items() if val]
-
-    if len(matched_groups) == 0:
-        # Should never happen
-        return m.group(0)
-
-    if len(matched_groups) > 1:
-        raise ValueError(
-            "only a single match should happen, your regex patterns are tangled: "
-            f"groups matched = {matched_groups} for the patterns: {repl_map.keys()}"
-        )
-
-    # Exactly one match => return replacement
-    name = matched_groups[0]
-    replacement = repl_map[name]
-    # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
-    if r"\1" in replacement and len(m.groups()) > 1:
-        replacement = replacement.replace(r"\1", m.group(1))
-
-    return replacement
-
-
 def rename_source_key(
     source_key: str,
-    rename_alternation: re.Pattern,
-    rename_by_group: dict,
-    weight_pattern_alternation: re.Pattern | None,
-    weight_pattern_by_group: dict | None,
+    weight_renamings: list[WeightRenaming],
+    weight_converters: list[WeightConverter],
     prefix: str | None = None,
    meta_state_dict: dict | None = None,
-) -> tuple[str, re.Match | None]:
+) -> tuple[str, str | None]:
     """
     Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing
     the base model prefix during loading if necessary.
     """
-    # 1. apply all renamings
-    renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key).replace("\\", "")
-
-    # 2. apply renaming through weight conversions on the key if we have any WeightConverter
-    matched_converter_pattern = (
-        weight_pattern_alternation.search(renamed_key) if weight_pattern_alternation is not None else None
-    )
-    if matched_converter_pattern is not None:
-        renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key).replace(
-            "\\", ""
-        )
+    renamed_key = source_key
+    # 1. apply all renamings in turns (if multiple match, it's the responsibility of the mappings to make sure they
+    # are coherent)
+    for renaming in weight_renamings:
+        renamed_key, _ = renaming.rename_source_key(renamed_key)
+
+    # 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after
+    # the first match, as we assume only 1 converter can match any source key)
+    source_pattern = None
+    for converter in weight_converters:
+        renamed_key, source_pattern = converter.rename_source_key(renamed_key)
+        if source_pattern is not None:
+            break
 
     # 3. check if we need to add or remove prefix if necessary (only during loading, not saving)
     if prefix is not None and meta_state_dict is not None:
@@ -672,7 +661,7 @@ def rename_source_key(
         elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
             renamed_key = f"{prefix}.{renamed_key}"
 
-    return renamed_key, matched_converter_pattern
+    return renamed_key, source_pattern
 
 
 def convert_and_load_state_dict_in_model(
@@ -799,10 +788,6 @@ def convert_and_load_state_dict_in_model(
 
     # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
     # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
-    rename_alt, _, rename_by_group = build_glob_alternation(renamings)
-    weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
-    if converters != []:
-        weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
     if tp_plan != {}:
         tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys()))
     if dtype_plan != {}:
@@ -813,24 +798,19 @@
     state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
     for original_key, tensor in state_dict:
         # 1. Rename the key according to all renaming pattern and optional weight converter patterns
-        renamed_key, matched_pattern = rename_source_key(
-            original_key,
-            rename_alt,
-            rename_by_group,
-            weight_pattern_alt,
-            tgt_group_to_glob,
-            prefix,
-            meta_model_state_dict,
+        renamed_key, source_pattern = rename_source_key(
+            original_key, renamings, converters, prefix, meta_model_state_dict
         )
 
         # 2. finally, collect the tensor into the proper converter
         if renamed_key in missing_keys:
             empty_param = meta_model_state_dict.get(renamed_key)
-            if matched_pattern:
-                new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
+            # If we enter here, we have a WeightConverter operation to perform
+            if source_pattern is not None:
+                new_converter = deepcopy(pattern_to_converter[source_pattern])
                 # each target key gets its own converter instance
                 mapping = param_name_to_load.setdefault(renamed_key, new_converter)
-                source_pattern = src_group_to_glob[matched_pattern.lastgroup]
+            # Otherwise, only potential renaming
             else:
                 mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
                 source_pattern = original_key
@@ -882,8 +862,8 @@
             future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
 
             mapping.add_tensor(renamed_key, original_key, source_pattern, future)
-        elif matched_pattern:  # add all target keys as unexpected
-            mapping = pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]]
+        elif source_pattern is not None:  # add all target keys as unexpected
+            mapping = pattern_to_converter[source_pattern]
             for k in mapping.target_patterns:
                 unexpected_keys.add(renamed_key.replace(mapping.target_patterns[0], k))
         else:
@@ -964,24 +944,14 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
     pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns}
     conversion_mapping = {}
 
-    # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
-    # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
-    rename_alt, _, rename_by_group = build_glob_alternation(renamings)
-    weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
-    if converters != []:
-        weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
-
     state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
     for original_key, tensor in state_dict:
         # Rename the key according to all renaming pattern and optional weight converter patterns
-        renamed_key, matched_pattern = rename_source_key(
-            original_key, rename_alt, rename_by_group, weight_pattern_alt, tgt_group_to_glob
-        )
-        if matched_pattern is not None:
-            new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
+        renamed_key, source_pattern = rename_source_key(original_key, renamings, converters)
+        if source_pattern is not None:
+            new_converter = deepcopy(pattern_to_converter[source_pattern])
             # each target key gets its own converter instance
             mapping = conversion_mapping.setdefault(renamed_key, new_converter)
-            source_pattern = src_group_to_glob[matched_pattern.lastgroup]
         else:
             mapping = conversion_mapping.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
             source_pattern = original_key

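The key change in this file is that renaming no longer goes through one big glob alternation built by build_glob_alternation plus repl; each WeightTransform now compiles its own alternation of named groups and exposes rename_source_key. A minimal standalone sketch of that idea, using only the re module (the pattern strings are borrowed from the mapping above, everything else is illustrative and not the library code):

import re

source_patterns = [r"mlp.experts.*.w1.weight", r"mlp.experts.*.w3.weight"]
target_patterns = ["mlp.experts.gate_up_proj"]

branches = []
for i, src in enumerate(source_patterns):
    # Same trick as __post_init__ above: turn the ".*." separator into an escaped-dot regex
    escaped = src.replace(".*.", r"\..*\.")
    branches.append(f"(?P<g{i}>{escaped})")
compiled = re.compile("|".join(branches))

key = "model.layers.0.mlp.experts.3.w3.weight"
match = compiled.search(key)
if match is not None:
    # The named group that matched tells us which source pattern fired
    group_name = next(name for name, val in match.groupdict().items() if val is not None)
    print(source_patterns[int(group_name[1:])])              # mlp.experts.*.w3.weight
    print(key.replace(match.group(0), target_patterns[0]))   # model.layers.0.mlp.experts.gate_up_proj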
src/transformers/integrations/accelerate.py

Lines changed: 3 additions & 9 deletions
@@ -480,18 +480,15 @@ def accelerate_disk_offload(
     renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside
     `disk_offload_folder` during loading.
     """
-    from ..core_model_loading import WeightRenaming, build_glob_alternation, repl
+    from ..core_model_loading import WeightRenaming, rename_source_key
 
     if disk_offload_folder is not None:
         os.makedirs(disk_offload_folder, exist_ok=True)
     is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
 
-    rename = False
+    renamings = []
     if weight_mapping is not None:
         renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
-        if len(renamings) > 0:
-            rename = True
-            rename_alt, _, rename_by_group = build_glob_alternation(renamings)
 
     # In this case, the offload index is simply the existing safetensors (except if using custom weight loading
     # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time)
@@ -505,10 +502,7 @@
             weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()}
 
         # Update the weight names according to the `weight_mapping`
-        weight_renaming_map = {
-            rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k
-            for k in weight_map
-        }
+        weight_renaming_map = {rename_source_key(k, renamings, [])[0]: k for k in weight_map}
 
         # Prepare the index using existing safetensors files
         disk_offload_index = {

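For disk offload, the renamed-to-original map is now built by calling rename_source_key with an empty converter list, so only plain WeightRenaming entries apply. A minimal sketch of the resulting {renamed_key: original_key} shape, assuming a hypothetical rename callable (not the library API):

def build_weight_renaming_map(weight_map, rename_key):
    # rename_key stands in for the renaming step (e.g. a closure over WeightRenaming entries);
    # the map points each renamed key back to the original checkpoint key on disk.
    return {rename_key(k): k for k in weight_map}

weight_map = {"model.w1": "shard-00001.safetensors", "model.w2": "shard-00002.safetensors"}
renamed = build_weight_renaming_map(weight_map, lambda k: k.replace("model.", "backbone."))
print(renamed)  # {'backbone.w1': 'model.w1', 'backbone.w2': 'model.w2'}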