@@ -105,7 +105,7 @@ class ConversionOps:
105105 """Base class for weight conversion operations."""
106106
def __repr__(self):
    """Debug representation: class name plus the op's `dim` and its `reverse_op`."""
    # The line adding `reverse_op` to the repr dropped the closing parenthesis
    # (`…, {self.reverse_op}"`), leaving the output unbalanced — restore it so the
    # repr reads e.g. "Chunk(dim=0, MergeModulelist(...))".
    return f"{self.__class__.__name__}(dim={self.dim}, {self.reverse_op})"
109109
110110 @abstractmethod
111111 def convert (
@@ -637,6 +637,44 @@ def repl(m, repl_map: dict[str, str]) -> str:
637637 return replacement
638638
639639
640+ def rename_source_key (
641+ source_key : str ,
642+ rename_alternation : re .Pattern ,
643+ rename_by_group : dict ,
644+ weight_pattern_alternation : re .Pattern ,
645+ weight_pattern_by_group : dict ,
646+ prefix : str | None = None ,
647+ meta_state_dict : dict | None = None ,
648+ ) -> tuple [str , re .Match | None ]:
649+ """
650+ Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing
651+ the base model prefix during loading if necesary.
652+ """
653+ # 1. apply all renamings
654+ renamed_key = rename_alternation .sub (lambda m : repl (m , rename_by_group ), source_key ).replace ("\\ " , "" )
655+
656+ # 2. apply renaming through weight conversions on the key if we have any WeightConverter
657+ matched_converter_pattern = (
658+ weight_pattern_alternation .search (renamed_key ) if weight_pattern_alternation is not None else None
659+ )
660+ if matched_converter_pattern is not None :
661+ renamed_key = weight_pattern_alternation .sub (lambda m : repl (m , weight_pattern_by_group ), renamed_key ).replace (
662+ "\\ " , ""
663+ )
664+
665+ # 3. check if we need to add or remove prefix if necesary (only during loading, not saving)
666+ if prefix is not None and meta_state_dict is not None :
667+ if (
668+ renamed_key .startswith (prefix )
669+ and meta_state_dict .get (re .sub (f"^{ prefix } ." , "" , renamed_key , count = 1 )) is not None
670+ ):
671+ renamed_key = re .sub (f"^{ prefix } ." , "" , renamed_key , count = 1 )
672+ elif meta_state_dict .get (f"{ prefix } .{ renamed_key } " ) is not None :
673+ renamed_key = f"{ prefix } .{ renamed_key } "
674+
675+ return renamed_key , matched_converter_pattern
676+
677+
640678def convert_and_load_state_dict_in_model (
641679 model : PreTrainedModel ,
642680 state_dict : dict [str , Any ],
@@ -762,6 +800,7 @@ def convert_and_load_state_dict_in_model(
762800 # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
763801 # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
764802 rename_alt , _ , rename_by_group = build_glob_alternation (renamings )
803+ weight_pattern_alt , src_group_to_glob , tgt_group_to_glob = None , None , None
765804 if converters != []:
766805 weight_pattern_alt , src_group_to_glob , tgt_group_to_glob = build_glob_alternation (converters )
767806 if tp_plan != {}:
@@ -773,24 +812,18 @@ def convert_and_load_state_dict_in_model(
773812
774813 state_dict = sorted (state_dict .items (), key = lambda kv : dot_natural_key (kv [0 ]))
775814 for original_key , tensor in state_dict :
776- # 1. apply all renamings
777- renamed_key = rename_alt .sub (lambda m : repl (m , rename_by_group ), original_key ).replace ("\\ " , "" )
778-
779- # 2. apply 1 weight conversion on the key
780- matched_pattern = weight_pattern_alt .search (renamed_key ) if converters != [] else None
781- if matched_pattern is not None : # we have a converter to apply
782- renamed_key = weight_pattern_alt .sub (lambda m : repl (m , tgt_group_to_glob ), renamed_key ).replace ("\\ " , "" )
783-
784- # 3. check if we need to add or remove prefix
785- if (
786- renamed_key .startswith (prefix )
787- and meta_model_state_dict .get (re .sub (f"^{ prefix } ." , "" , renamed_key , count = 1 )) is not None
788- ):
789- renamed_key = re .sub (f"^{ prefix } ." , "" , renamed_key , count = 1 )
790- elif meta_model_state_dict .get (f"{ prefix } .{ renamed_key } " ) is not None :
791- renamed_key = f"{ prefix } .{ renamed_key } "
815+ # 1. Rename the key according to all renaming pattern and optional weight converter patterns
816+ renamed_key , matched_pattern = rename_source_key (
817+ original_key ,
818+ rename_alt ,
819+ rename_by_group ,
820+ weight_pattern_alt ,
821+ tgt_group_to_glob ,
822+ prefix ,
823+ meta_model_state_dict ,
824+ )
792825
793- # 4 . finally, collect the tensor into the proper converter
826+ # 2 . finally, collect the tensor into the proper converter
794827 if renamed_key in missing_keys :
795828 empty_param = meta_model_state_dict .get (renamed_key )
796829 if matched_pattern :
@@ -802,7 +835,7 @@ def convert_and_load_state_dict_in_model(
802835 mapping = param_name_to_load .setdefault (renamed_key , WeightRenaming (original_key , renamed_key ))
803836 source_pattern = original_key
804837
805- # 5 . Handle dtype casting
838+ # 3 . Handle dtype casting
806839 if (
807840 hf_quantizer
808841 and not hf_quantizer .pre_quantized
@@ -822,7 +855,7 @@ def convert_and_load_state_dict_in_model(
822855 elif empty_param is not None and empty_param .dtype != _dtype :
823856 _dtype = empty_param .dtype # usually correct when initializing
824857
825- # 6 . Handle TP sharding or device_map placement -> scheduled materialization
858+ # 4 . Handle TP sharding or device_map placement -> scheduled materialization
826859 future = None
827860 if device_mesh :
828861 if matched_tp_pattern := tp_plan_alt .search (renamed_key ):
@@ -936,20 +969,17 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
936969 # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
937970 # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
938971 rename_alt , _ , rename_by_group = build_glob_alternation (renamings )
972+ weight_pattern_alt , src_group_to_glob , tgt_group_to_glob = None , None , None
939973 if converters != []:
940974 weight_pattern_alt , src_group_to_glob , tgt_group_to_glob = build_glob_alternation (converters )
941975
942976 state_dict = sorted (state_dict .items (), key = lambda kv : dot_natural_key (kv [0 ]))
943977 for original_key , tensor in state_dict :
944- # apply all renamings
945- renamed_key = rename_alt .sub (lambda m : repl (m , rename_by_group ), original_key ).replace ("\\ " , "" )
946-
947- # apply weight conversion on the key
948- matched_pattern = weight_pattern_alt .search (renamed_key ) if converters != [] else None
978+ # Rename the key according to all renaming pattern and optional weight converter patterns
979+ renamed_key , matched_pattern = rename_source_key (
980+ original_key , rename_alt , rename_by_group , weight_pattern_alt , tgt_group_to_glob
981+ )
949982 if matched_pattern is not None :
950- renamed_key = weight_pattern_alt .sub (lambda m : repl (m , tgt_group_to_glob ), renamed_key ).replace (
951- "\\ " , ""
952- )
953983 new_converter = deepcopy (pattern_to_converter [src_group_to_glob [matched_pattern .lastgroup ]])
954984 # each target key gets its own converter instance
955985 mapping = conversion_mapping .setdefault (renamed_key , new_converter )
0 commit comments