Skip to content

Commit 8cda0fd

Browse files
committed
utils to rename keys
1 parent 1a240e6 commit 8cda0fd

File tree

1 file changed

+58
-28
lines changed

1 file changed

+58
-28
lines changed

src/transformers/core_model_loading.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class ConversionOps:
105105
"""Base class for weight conversion operations."""
106106

107107
def __repr__(self):
108-
return f"{self.__class__.__name__}(dim={self.dim})"
108+
        return f"{self.__class__.__name__}(dim={self.dim}, {self.reverse_op})"
109109

110110
@abstractmethod
111111
def convert(
@@ -637,6 +637,44 @@ def repl(m, repl_map: dict[str, str]) -> str:
637637
return replacement
638638

639639

640+
def rename_source_key(
641+
source_key: str,
642+
rename_alternation: re.Pattern,
643+
rename_by_group: dict,
644+
weight_pattern_alternation: re.Pattern,
645+
weight_pattern_by_group: dict,
646+
prefix: str | None = None,
647+
meta_state_dict: dict | None = None,
648+
) -> tuple[str, re.Match | None]:
649+
"""
650+
Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing
651+
    the base model prefix during loading if necessary.
652+
"""
653+
# 1. apply all renamings
654+
renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key).replace("\\", "")
655+
656+
# 2. apply renaming through weight conversions on the key if we have any WeightConverter
657+
matched_converter_pattern = (
658+
weight_pattern_alternation.search(renamed_key) if weight_pattern_alternation is not None else None
659+
)
660+
if matched_converter_pattern is not None:
661+
renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key).replace(
662+
"\\", ""
663+
)
664+
665+
    # 3. check if we need to add or remove the prefix (only during loading, not saving)
666+
if prefix is not None and meta_state_dict is not None:
667+
if (
668+
renamed_key.startswith(prefix)
669+
and meta_state_dict.get(re.sub(f"^{prefix}.", "", renamed_key, count=1)) is not None
670+
):
671+
renamed_key = re.sub(f"^{prefix}.", "", renamed_key, count=1)
672+
elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
673+
renamed_key = f"{prefix}.{renamed_key}"
674+
675+
return renamed_key, matched_converter_pattern
676+
677+
640678
def convert_and_load_state_dict_in_model(
641679
model: PreTrainedModel,
642680
state_dict: dict[str, Any],
@@ -762,6 +800,7 @@ def convert_and_load_state_dict_in_model(
762800
# build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
763801
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
764802
rename_alt, _, rename_by_group = build_glob_alternation(renamings)
803+
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
765804
if converters != []:
766805
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
767806
if tp_plan != {}:
@@ -773,24 +812,18 @@ def convert_and_load_state_dict_in_model(
773812

774813
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
775814
for original_key, tensor in state_dict:
776-
# 1. apply all renamings
777-
renamed_key = rename_alt.sub(lambda m: repl(m, rename_by_group), original_key).replace("\\", "")
778-
779-
# 2. apply 1 weight conversion on the key
780-
matched_pattern = weight_pattern_alt.search(renamed_key) if converters != [] else None
781-
if matched_pattern is not None: # we have a converter to apply
782-
renamed_key = weight_pattern_alt.sub(lambda m: repl(m, tgt_group_to_glob), renamed_key).replace("\\", "")
783-
784-
# 3. check if we need to add or remove prefix
785-
if (
786-
renamed_key.startswith(prefix)
787-
and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", renamed_key, count=1)) is not None
788-
):
789-
renamed_key = re.sub(f"^{prefix}.", "", renamed_key, count=1)
790-
elif meta_model_state_dict.get(f"{prefix}.{renamed_key}") is not None:
791-
renamed_key = f"{prefix}.{renamed_key}"
815+
        # 1. Rename the key according to all renaming patterns and optional weight converter patterns
816+
renamed_key, matched_pattern = rename_source_key(
817+
original_key,
818+
rename_alt,
819+
rename_by_group,
820+
weight_pattern_alt,
821+
tgt_group_to_glob,
822+
prefix,
823+
meta_model_state_dict,
824+
)
792825

793-
# 4. finally, collect the tensor into the proper converter
826+
# 2. finally, collect the tensor into the proper converter
794827
if renamed_key in missing_keys:
795828
empty_param = meta_model_state_dict.get(renamed_key)
796829
if matched_pattern:
@@ -802,7 +835,7 @@ def convert_and_load_state_dict_in_model(
802835
mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
803836
source_pattern = original_key
804837

805-
# 5. Handle dtype casting
838+
# 3. Handle dtype casting
806839
if (
807840
hf_quantizer
808841
and not hf_quantizer.pre_quantized
@@ -822,7 +855,7 @@ def convert_and_load_state_dict_in_model(
822855
elif empty_param is not None and empty_param.dtype != _dtype:
823856
_dtype = empty_param.dtype # usually correct when initializing
824857

825-
# 6. Handle TP sharding or device_map placement -> scheduled materialization
858+
# 4. Handle TP sharding or device_map placement -> scheduled materialization
826859
future = None
827860
if device_mesh:
828861
if matched_tp_pattern := tp_plan_alt.search(renamed_key):
@@ -936,20 +969,17 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
936969
# build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
937970
# and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
938971
rename_alt, _, rename_by_group = build_glob_alternation(renamings)
972+
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
939973
if converters != []:
940974
weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
941975

942976
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
943977
for original_key, tensor in state_dict:
944-
# apply all renamings
945-
renamed_key = rename_alt.sub(lambda m: repl(m, rename_by_group), original_key).replace("\\", "")
946-
947-
# apply weight conversion on the key
948-
matched_pattern = weight_pattern_alt.search(renamed_key) if converters != [] else None
978+
        # Rename the key according to all renaming patterns and optional weight converter patterns
979+
renamed_key, matched_pattern = rename_source_key(
980+
original_key, rename_alt, rename_by_group, weight_pattern_alt, tgt_group_to_glob
981+
)
949982
if matched_pattern is not None:
950-
renamed_key = weight_pattern_alt.sub(lambda m: repl(m, tgt_group_to_glob), renamed_key).replace(
951-
"\\", ""
952-
)
953983
new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
954984
# each target key gets its own converter instance
955985
mapping = conversion_mapping.setdefault(renamed_key, new_converter)

0 commit comments

Comments
 (0)