logger = logging.get_logger(__name__)


-def compile_glob_rule(source_glob: str, target_glob: str) -> tuple[re.Pattern, str]:
-    """
-    Convert a glob-style source + target into a full regex + replacement.
-
-    Rules:
-    - '*' in source_glob → (.*) capture group
-    - '*' in target_glob → \\1, \\2, ... backrefs
-    """
-    regex = re.compile(source_glob)
-
-    counter = 0
-
-    def _star_to_backref(_: re.Match) -> str:
-        nonlocal counter
-        counter += 1
-        return rf"\{counter}"
-
-    replacement = re.sub(r"\*", _star_to_backref, target_glob)
-    return regex, replacement
-
-
def build_glob_alternation(
    globs: list[Union[WeightRenaming, WeightConverter, str]],
) -> tuple[re.Pattern, dict[str, str], dict[str, str]]:
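For reference, a standalone sketch of what the removed compile_glob_rule did to the target glob (the example glob below is illustrative, not taken from this PR): every `*` in the target becomes the next numbered backreference.

import re

def stars_to_backrefs(target_glob: str) -> str:
    # Same counter-based substitution as the deleted helper above.
    counter = 0

    def _star_to_backref(_: re.Match) -> str:
        nonlocal counter
        counter += 1
        return rf"\{counter}"

    return re.sub(r"\*", _star_to_backref, target_glob)

print(stars_to_backrefs("*.mlp.*"))  # \1.mlp.\2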
@@ -303,6 +282,7 @@ def convert(
class WeightTransform:
    source_patterns: Union[str, list[str]] = field(init=True)
    target_patterns: Union[str, list[str]] = field(init=True)
+    compiled_sources: re.Pattern = field(init=False)

    distributed_operation: Optional[TensorParallelLayer] = None
    quantization_operation: Optional[ConversionOps] = None
@@ -322,20 +302,27 @@ def __post_init__(self):
        for i, pattern in enumerate(self.target_patterns):
            # Some mappings contain `^` to mark the start of the string when matching -> remove it during reverse mapping
            pattern = pattern.removeprefix("^")
-            # This is ugly but needed for reverse mapping of Qwen2.5!
-            if r"(?!\.(language_model|visual))" in pattern:
-                pattern = pattern.replace(r"(?!\.(language_model|visual))", "")
-            # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
+            # Remove negative lookahead if any. This is ugly but needed for reverse mapping of Qwen2.5 and Sam3!
+            pattern = re.sub(r"\(\?!.+\)", "", pattern)
+            # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
            if r"(.+)" in pattern:
-                pattern = pattern.replace(r"(.+)", "")
+                pattern = pattern.replace(r"(.+)", r"\1")
            self.target_patterns[i] = pattern

-        # We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper)
+        # We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper, sam3)
        for i, pattern in enumerate(self.source_patterns):
            if r"\1" in pattern:
-                pattern = pattern.replace(r"\1", "")
+                pattern = pattern.replace(r"\1", r"(.+)")
            self.source_patterns[i] = pattern

+        # Construct the regex we will use to rename keys from the sources to the targets
+        branches = []
+        for i, source_pattern in enumerate(self.source_patterns):
+            group_name = f"g{i}"
+            pattern = source_pattern.replace(".*.", r"\..*\.")
+            branches.append(f"(?P<{group_name}>{pattern})")
+        self.compiled_sources = re.compile("|".join(branches))
+
    def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future):
        self.collected_tensors[source_pattern].append(future)
        self.layer_targets[target_key].add(source_key)
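For intuition, a minimal standalone sketch of the alternation built in __post_init__ above (the patterns here are hypothetical, not from the PR): each source pattern becomes its own named branch, so a single re.search reveals which pattern matched a given checkpoint key.

import re

source_patterns = [r"mlp\.experts\.gate_proj", r"mlp\.experts\.up_proj"]
branches = [f"(?P<g{i}>{p})" for i, p in enumerate(source_patterns)]
compiled_sources = re.compile("|".join(branches))  # (?P<g0>...)|(?P<g1>...)

match_object = compiled_sources.search("model.layers.3.mlp.experts.up_proj.weight")
matched = next(name for name, val in match_object.groupdict().items() if val is not None)
print(matched, source_patterns[int(matched[1:])])  # g1 mlp\.experts\.up_proj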
@@ -344,6 +331,32 @@ def reset(self) -> None:
344331 """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory."""
345332 self .collected_tensors = defaultdict (list )
346333
334+ def rename_source_key (self , source_key : str ) -> tuple [str , str | None ]:
335+ """
336+ Return a tuple (renamed_key, source_pattern_producing_the_match).
337+ Try renaming `source_key` according to the source and target patterns of the current WeightTransform.
338+ In case of a one-to-many transform, i.e. we have several target patterns, the matching source pattern
339+ will be replaced by the first of all the target patterns (they are then correctly expanded in the Operations).
340+ """
341+ # Try matching one of the alternation branches
342+ match_object = self .compiled_sources .search (source_key )
343+ if match_object is None :
344+ return source_key , None
345+ # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match)
346+ matching_group_name = next (name for name , val in match_object .groupdict ().items () if val is not None )
347+ source_pattern_that_matched = self .source_patterns [int (matching_group_name [1 :])]
348+ # If we matched, we always replace with the first target pattern, in case we have several (one to many transform)
349+ replacement = self .target_patterns [0 ]
350+ # # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3)
351+ if r"\1" in replacement :
352+ # The index of the internal group we need to replace is the index of the matched named group as it comes
353+ # inside that matched named group
354+ replaced_group_idx = self .compiled_sources .groupindex [matching_group_name ] + 1
355+ replacement = replacement .replace (r"\1" , match_object .group (replaced_group_idx ))
356+ renamed_key = source_key .replace (match_object .group (0 ), replacement )
357+
358+ return renamed_key , source_pattern_that_matched
359+
347360 def reverse_transform (self ) -> WeightTransform :
348361 """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations."""
349362 # TODO: check this and relax when quantizer have `reverse_op`
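A sketch of the `\1` branch of rename_source_key above, using a hypothetical timm-style prefix pattern: because the inner capture group is nested directly inside the matched named branch, its index in the compiled regex is groupindex[name] + 1.

import re

source_patterns = [r"(.+)\.attention"]         # hypothetical source pattern
target_patterns = [r"timm_model.\1.attn"]      # hypothetical target pattern
compiled_sources = re.compile("|".join(f"(?P<g{i}>{p})" for i, p in enumerate(source_patterns)))

source_key = "blocks.0.attention.qkv.weight"
match_object = compiled_sources.search(source_key)
name = next(n for n, v in match_object.groupdict().items() if v is not None)
inner = match_object.group(compiled_sources.groupindex[name] + 1)  # "blocks.0"
replacement = target_patterns[0].replace(r"\1", inner)
print(source_key.replace(match_object.group(0), replacement))     # timm_model.blocks.0.attn.qkv.weight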
@@ -613,54 +626,30 @@ class SkipLayer(Exception):
    pass


-def repl(m, repl_map: dict[str, str]) -> str:
-    # Collect all groups that matched
-    matched_groups = [name for name, val in m.groupdict().items() if val]
-
-    if len(matched_groups) == 0:
-        # Should never happen
-        return m.group(0)
-
-    if len(matched_groups) > 1:
-        raise ValueError(
-            "only a single match should happen, your regex patterns are tangled: "
627- f"groups matched = { matched_groups } for the patternsL { repl_map .keys ()} "
-        )
-
-    # Exactly one match => return replacement
-    name = matched_groups[0]
-    replacement = repl_map[name]
-    # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper)
-    if r"\1" in replacement and len(m.groups()) > 1:
-        replacement = replacement.replace(r"\1", m.group(1))
-
-    return replacement
-
-
def rename_source_key(
    source_key: str,
-    rename_alternation: re.Pattern,
-    rename_by_group: dict,
-    weight_pattern_alternation: re.Pattern | None,
-    weight_pattern_by_group: dict | None,
+    weight_renamings: list[WeightRenaming],
+    weight_converters: list[WeightConverter],
    prefix: str | None = None,
    meta_state_dict: dict | None = None,
-) -> tuple[str, re.Match | None]:
+) -> tuple[str, str | None]:
    """
    Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing
    the base model prefix during loading if necessary.
    """
-    # 1. apply all renamings
-    renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key).replace("\\", "")
-
-    # 2. apply renaming through weight conversions on the key if we have any WeightConverter
-    matched_converter_pattern = (
-        weight_pattern_alternation.search(renamed_key) if weight_pattern_alternation is not None else None
-    )
-    if matched_converter_pattern is not None:
-        renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key).replace(
-            "\\", ""
-        )
+    renamed_key = source_key
+    # 1. apply all renamings in turn (if multiple match, it's the responsibility of the mappings to make sure they
+    # are coherent)
+    for renaming in weight_renamings:
+        renamed_key, _ = renaming.rename_source_key(renamed_key)
+
+    # 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after
+    # the first match, as we assume only 1 converter can match any source key)
+    source_pattern = None
+    for converter in weight_converters:
+        renamed_key, source_pattern = converter.rename_source_key(renamed_key)
+        if source_pattern is not None:
+            break

    # 3. check if we need to add or remove the prefix (only during loading, not saving)
    if prefix is not None and meta_state_dict is not None:
@@ -672,7 +661,7 @@ def rename_source_key(
        elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
            renamed_key = f"{prefix}.{renamed_key}"

-    return renamed_key, matched_converter_pattern
+    return renamed_key, source_pattern


def convert_and_load_state_dict_in_model(
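A hedged sketch of step 3 (prefix handling), with hypothetical keys: the base-model prefix is only added when the unprefixed name is absent from the meta state dict but the prefixed one exists (the symmetric branch removes it instead).

# Hypothetical meta state dict: parameter names mapped to meta tensors (placeholder values here).
meta_state_dict = {"model.embed_tokens.weight": "<meta tensor>"}
prefix, renamed_key = "model", "embed_tokens.weight"

if meta_state_dict.get(renamed_key) is None and meta_state_dict.get(f"{prefix}.{renamed_key}") is not None:
    renamed_key = f"{prefix}.{renamed_key}"

print(renamed_key)  # model.embed_tokens.weight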
@@ -799,10 +788,6 @@ def convert_and_load_state_dict_in_model(

    # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
    # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
-    rename_alt, _, rename_by_group = build_glob_alternation(renamings)
-    weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
-    if converters != []:
-        weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
    if tp_plan != {}:
        tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys()))
    if dtype_plan != {}:
@@ -813,24 +798,19 @@ def convert_and_load_state_dict_in_model(
    state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
    for original_key, tensor in state_dict:
        # 1. Rename the key according to all renaming patterns and optional weight converter patterns
-        renamed_key, matched_pattern = rename_source_key(
-            original_key,
-            rename_alt,
-            rename_by_group,
-            weight_pattern_alt,
-            tgt_group_to_glob,
-            prefix,
-            meta_model_state_dict,
+        renamed_key, source_pattern = rename_source_key(
+            original_key, renamings, converters, prefix, meta_model_state_dict
        )

        # 2. finally, collect the tensor into the proper converter
        if renamed_key in missing_keys:
            empty_param = meta_model_state_dict.get(renamed_key)
-            if matched_pattern:
-                new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
+            # If we enter here, we have a WeightConverter operation to perform
+            if source_pattern is not None:
+                new_converter = deepcopy(pattern_to_converter[source_pattern])
                # each target key gets its own converter instance
                mapping = param_name_to_load.setdefault(renamed_key, new_converter)
-                source_pattern = src_group_to_glob[matched_pattern.lastgroup]
+            # Otherwise, only a potential renaming
            else:
                mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
                source_pattern = original_key
@@ -882,8 +862,8 @@ def convert_and_load_state_dict_in_model(
            future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
            future = spawn_materialize(thread_pool, tensor, param_device, _dtype)
            mapping.add_tensor(renamed_key, original_key, source_pattern, future)
-        elif matched_pattern:  # add all target keys as unexpected
-            mapping = pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]]
+        elif source_pattern is not None:  # add all target keys as unexpected
+            mapping = pattern_to_converter[source_pattern]
            for k in mapping.target_patterns:
                unexpected_keys.add(renamed_key.replace(mapping.target_patterns[0], k))
        else:
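An illustration of the unexpected-key expansion above, with a hypothetical one-to-many converter (a fused qkv weight split into separate q/k/v projections): every expanded target name is reported as unexpected, not just the first one.

# Hypothetical target patterns of a one-to-many converter; after rename_source_key the
# renamed key already carries the first target pattern.
target_patterns = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
renamed_key = "model.layers.0.self_attn.q_proj.weight"

unexpected_keys = {renamed_key.replace(target_patterns[0], k) for k in target_patterns}
print(sorted(unexpected_keys))
# ['model.layers.0.self_attn.k_proj.weight',
#  'model.layers.0.self_attn.q_proj.weight',
#  'model.layers.0.self_attn.v_proj.weight']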
@@ -964,24 +944,14 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
    pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns}
    conversion_mapping = {}

-    # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
-    # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
-    rename_alt, _, rename_by_group = build_glob_alternation(renamings)
-    weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None
-    if converters != []:
-        weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
-
    state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
    for original_key, tensor in state_dict:
        # Rename the key according to all renaming patterns and optional weight converter patterns
-        renamed_key, matched_pattern = rename_source_key(
-            original_key, rename_alt, rename_by_group, weight_pattern_alt, tgt_group_to_glob
-        )
-        if matched_pattern is not None:
-            new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]])
+        renamed_key, source_pattern = rename_source_key(original_key, renamings, converters)
+        if source_pattern is not None:
+            new_converter = deepcopy(pattern_to_converter[source_pattern])
            # each target key gets its own converter instance
            mapping = conversion_mapping.setdefault(renamed_key, new_converter)
-            source_pattern = src_group_to_glob[matched_pattern.lastgroup]
        else:
            mapping = conversion_mapping.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
            source_pattern = original_key