# See the License for the specific language governing permissions and
# limitations under the License.

+from __future__ import annotations
+
from copy import deepcopy
+from typing import TYPE_CHECKING

from .core_model_loading import Concatenate, MergeModulelist, WeightConverter, WeightRenaming
from .utils import is_torch_available


if is_torch_available():
    import torch


+if TYPE_CHECKING:
+    from .modeling_utils import PreTrainedModel
+    from .quantizers import HfQuantizer
+
+
def _build_checkpoint_conversion_mapping():
    mapping = {
        "mixtral": [
            WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
            WeightConverter(
-                source_keys=[
+                source_patterns=[
                    "block_sparse_moe.experts.*.w1.weight",
                    "block_sparse_moe.experts.*.w3.weight",
                ],  # you give me a list of 2 keys, I collect a list of a list of tensors
-                target_keys="mlp.experts.gate_up_proj",  # target key gets the list of two tensors
+                target_patterns="mlp.experts.gate_up_proj",  # target key gets the list of two tensors
                operations=[
                    MergeModulelist(
                        dim=0
@@ -41,10 +49,10 @@ def _build_checkpoint_conversion_mapping():
                ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
            ),
            WeightConverter(
-                source_keys=[
+                source_patterns=[
                    "block_sparse_moe.experts.*.w2.weight",
                ],
-                target_keys="mlp.experts.down_proj",  # target key gets the list of two tensors
+                target_patterns="mlp.experts.down_proj",  # target key gets the list of two tensors
                operations=[
                    MergeModulelist(
                        dim=0
@@ -54,50 +62,58 @@ def _build_checkpoint_conversion_mapping():
        ],
        "qwen2_moe": [
            WeightConverter(
-                source_keys=[
+                source_patterns=[
                    "mlp.experts.*.gate_proj.weight",
                    "mlp.experts.*.up_proj.weight",
                ],
-                target_keys="mlp.experts.gate_up_proj",
+                target_patterns="mlp.experts.gate_up_proj",
                operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
            ),
            WeightConverter(
-                source_keys=["mlp.experts.*.down_proj.weight"],
-                target_keys="mlp.experts.down_proj",
+                source_patterns=["mlp.experts.*.down_proj.weight"],
+                target_patterns="mlp.experts.down_proj",
                operations=[MergeModulelist(dim=0)],
            ),
        ],
+        "timm_wrapper": [
+            # Simply add the prefix `timm_model`
+            # TODO: would probably be much cleaner with an `add_prefix` argument in WeightRenaming
+            WeightRenaming(
+                source_patterns=r"(.+)",
+                target_patterns=r"timm_model.\1",
+            )
+        ],
        "legacy": [
            WeightRenaming(
-                source_keys="LayerNorm.gamma",
-                target_keys="LayerNorm.weight",
+                source_patterns="LayerNorm.gamma",
+                target_patterns="LayerNorm.weight",
            ),
            WeightRenaming(
-                source_keys="LayerNorm.beta",
-                target_keys="LayerNorm.bias",
+                source_patterns="LayerNorm.beta",
+                target_patterns="LayerNorm.bias",
            ),
        ],
    }
    if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
        mapping["legacy"] += [
            WeightRenaming(
-                source_keys="weight_g",
-                target_keys="parametrizations.weight.original0",
+                source_patterns="weight_g",
+                target_patterns="parametrizations.weight.original0",
            ),
            WeightRenaming(
-                source_keys="weight_v",
-                target_keys="parametrizations.weight.original1",
+                source_patterns="weight_v",
+                target_patterns="parametrizations.weight.original1",
            ),
        ]
    else:
        mapping["legacy"] += [
            WeightRenaming(
-                source_keys="parametrizations.weight.original0",
-                target_keys="weight_g",
+                source_patterns="parametrizations.weight.original0",
+                target_patterns="weight_g",
            ),
            WeightRenaming(
-                source_keys="parametrizations.weight.original1",
-                target_keys="weight_v",
+                source_patterns="parametrizations.weight.original1",
+                target_patterns="weight_v",
            ),
        ]

@@ -127,5 +143,72 @@ def _build_checkpoint_conversion_mapping():
def get_checkpoint_conversion_mapping(model_type):
    global _checkpoint_conversion_mapping_cache
    _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
-    globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
-    return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type, None))
+    return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type))
+
+
+# DO NOT MODIFY, KEPT FOR BC ONLY
+VLMS = [
+    "aria",
+    "ayavision",
+    "colpali",
+    "emu3",
+    "fuyu",
+    "gotocr2",
+    "gemma3",
+    "internvl",
+    "llava",  # all llava prefixed models fall under this check
+    "mistral3",
+    "mllama",
+    "paligemma",
+    "shieldgemma2",
+    "qwen2vl",
+    "qwen2_5_vl",
+    "videollava",
+    "vipllava",
+    "sam3_video",
+    "sam3",
+    "sam3_tracker",
+    "sam3_tracker_video",
+]
+
+
+def get_model_conversion_mapping(
+    model: PreTrainedModel,
+    key_mapping: dict[str, str] | None = None,
+    hf_quantizer: HfQuantizer | None = None,
+    add_legacy: bool = True,
+) -> list[WeightConverter | WeightRenaming]:
+    """
+    For a given `model`, obtain the weight conversion mapping, if any are registered: either simple renamings
+    from the `_checkpoint_conversion_mapping` class attribute, or entries from the general WeightConverter mapping.
+    """
+    weight_conversions = []
+
+    # Load models with key mapping
+    if key_mapping is not None:
+        weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]
+    elif any(
+        allowed_name in class_name.__name__.lower()
+        for class_name in model.__class__.__mro__[:-1]
+        for allowed_name in VLMS
+    ):
+        weight_conversions = [
+            WeightRenaming(source_patterns=k, target_patterns=v)
+            for k, v in model._checkpoint_conversion_mapping.items()
+        ]
+
+    # TODO: should be checked recursively on submodels!!
+    model_type = getattr(model.config, "model_type", None)
+    if model_type is not None:
+        model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
+        if model_specific_conversions is not None:
+            weight_conversions.extend(model_specific_conversions)
+
+    if add_legacy:
+        weight_conversions.extend(get_checkpoint_conversion_mapping("legacy"))
+
+    # Add the ones from the quantizer as well if provided
+    if hf_quantizer is not None:
+        weight_conversions.extend(hf_quantizer.get_weight_conversions())
+
+    return weight_conversions
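
A rough usage sketch, not part of the diff: assuming this module ends up importable as `transformers.conversion_mapping`, and using an illustrative tiny Mixtral config just to build a model skeleton, the new `get_model_conversion_mapping` helper could be exercised roughly as follows. The module path, the config sizes, and the attribute names read back via `getattr` are assumptions.

# Illustrative sketch only: module path, config sizes, and attribute names are assumptions.
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.conversion_mapping import get_model_conversion_mapping  # assumed module path

config = AutoConfig.for_model(
    "mixtral",
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
)
model = AutoModelForCausalLM.from_config(config)  # build the model skeleton, no checkpoint needed

for conv in get_model_conversion_mapping(model):
    # Each entry is a WeightRenaming or WeightConverter describing how checkpoint keys map onto
    # the model's parameters (e.g. merging per-expert w1/w3 weights into mlp.experts.gate_up_proj).
    print(
        type(conv).__name__,
        getattr(conv, "source_patterns", None),
        "->",
        getattr(conv, "target_patterns", None),
    )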