
Commit ec3f555

Cyrilvallez and 3outeille authored and committed
[loading/saving] Reverse all loading operations when saving (#42396)
* first shot
* default to reversing
* oupso
* oupsi 2
* oupsi 3
* fix renamed kwargs
* fix timm_wrapper
* remove fix_state_dict methods
* can do it all the time, with __init__ as well
* doc
* oupsi
* fix
* create helper
* fix annoying annotation issue
* small fix
* small fixes
* alright commit all that already
* oupsi
* the fix
* update quantizers
* this works
* the hardcoded regex got me hard....
* style
* the final one
* cleanup a bit
* better
* style
* oupsi readded it
* do it inside the ops instead - no need for full names anymore
* reverse quantizers and simplify signatures
* small thingy
* add no_grad decorator
* utils to rename keys
* oupssii again
* add test
* simplify nicely
1 parent 14b7ac0 commit ec3f555
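
In short, the change makes saving replay each load-time weight conversion in the opposite direction, so a checkpoint round-trips back to its original key names. A minimal standalone sketch of that idea (illustrative names only, not the transformers implementation):

```python
# Toy sketch: renamings applied while loading are replayed in reverse when
# saving. The helper and tables below are illustrative, not the real API.
LOAD_RENAMES = {
    "LayerNorm.gamma": "LayerNorm.weight",
    "LayerNorm.beta": "LayerNorm.bias",
}
SAVE_RENAMES = {v: k for k, v in LOAD_RENAMES.items()}  # reversed direction

def rename_key(key: str, table: dict[str, str]) -> str:
    # Replace any known suffix; leave other keys untouched.
    for src, tgt in table.items():
        if key.endswith(src):
            return key[: -len(src)] + tgt
    return key

loaded = rename_key("encoder.LayerNorm.gamma", LOAD_RENAMES)  # -> ...LayerNorm.weight
saved = rename_key(loaded, SAVE_RENAMES)                      # -> ...LayerNorm.gamma
assert saved == "encoder.LayerNorm.gamma"
```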

File tree

10 files changed: +643, -316 lines

src/transformers/conversion_mapping.py: 105 additions, 22 deletions
```diff
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from copy import deepcopy
+from typing import TYPE_CHECKING
 
 from .core_model_loading import Concatenate, MergeModulelist, WeightConverter, WeightRenaming
 from .utils import is_torch_available
```
```diff
@@ -23,16 +26,21 @@
     import torch
 
 
+if TYPE_CHECKING:
+    from .modeling_utils import PreTrainedModel
+    from .quantizers import HfQuantizer
+
+
 def _build_checkpoint_conversion_mapping():
     mapping = {
         "mixtral": [
             WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
             WeightConverter(
-                source_keys=[
+                source_patterns=[
                     "block_sparse_moe.experts.*.w1.weight",
                     "block_sparse_moe.experts.*.w3.weight",
                 ],  # you give me a list of 2 keys, I collect a list of a list of tensors
-                target_keys="mlp.experts.gate_up_proj",  # target key gets the list of two tensors
+                target_patterns="mlp.experts.gate_up_proj",  # target key gets the list of two tensors
                 operations=[
                     MergeModulelist(
                         dim=0
```
```diff
@@ -41,10 +49,10 @@ def _build_checkpoint_conversion_mapping():
                 ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
             ),
             WeightConverter(
-                source_keys=[
+                source_patterns=[
                     "block_sparse_moe.experts.*.w2.weight",
                 ],
-                target_keys="mlp.experts.down_proj",  # target key gets the list of two tensors
+                target_patterns="mlp.experts.down_proj",  # target key gets the list of two tensors
                 operations=[
                     MergeModulelist(
                         dim=0
```
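
For intuition, this is roughly the tensor math the Mixtral mapping performs at load time, and the inverse that saving now has to apply: MergeModulelist(dim=0) stacks the per-expert weights, then Concatenate(dim=1) fuses w1/w3 into one gate_up_proj tensor. A sketch with plain torch ops; the shapes and the chunk-based inversion are my own illustration, not code from the commit:

```python
import torch

num_experts, hidden, inter = 4, 8, 16
w1 = [torch.randn(inter, hidden) for _ in range(num_experts)]  # per-expert w1 (gate)
w3 = [torch.randn(inter, hidden) for _ in range(num_experts)]  # per-expert w3 (up)

# Load direction: MergeModulelist(dim=0) stacks each expert list...
stacked_w1 = torch.stack(w1, dim=0)  # (num_experts, inter, hidden)
stacked_w3 = torch.stack(w3, dim=0)
# ...then Concatenate(dim=1) fuses them into one gate_up_proj tensor.
gate_up = torch.cat([stacked_w1, stacked_w3], dim=1)  # (num_experts, 2 * inter, hidden)

# Save direction (what this PR adds): run the inverse ops in reverse order.
back_w1, back_w3 = gate_up.chunk(2, dim=1)
assert torch.equal(back_w1, stacked_w1) and torch.equal(back_w3, stacked_w3)
```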
```diff
@@ -54,50 +62,58 @@ def _build_checkpoint_conversion_mapping():
         ],
         "qwen2_moe": [
             WeightConverter(
-                source_keys=[
+                source_patterns=[
                     "mlp.experts.*.gate_proj.weight",
                     "mlp.experts.*.up_proj.weight",
                 ],
-                target_keys="mlp.experts.gate_up_proj",
+                target_patterns="mlp.experts.gate_up_proj",
                 operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
             ),
             WeightConverter(
-                source_keys=["mlp.experts.*.down_proj.weight"],
-                target_keys="mlp.experts.down_proj",
+                source_patterns=["mlp.experts.*.down_proj.weight"],
+                target_patterns="mlp.experts.down_proj",
                 operations=[MergeModulelist(dim=0)],
             ),
         ],
+        "timm_wrapper": [
+            # Simply add the prefix `timm_model`
+            # TODO: would probably be much cleaner with an `add_prefix` argument in WeightRenaming
+            WeightRenaming(
+                source_patterns=r"(.+)",
+                target_patterns=r"timm_model.\1",
+            )
+        ],
         "legacy": [
             WeightRenaming(
-                source_keys="LayerNorm.gamma",
-                target_keys="LayerNorm.weight",
+                source_patterns="LayerNorm.gamma",
+                target_patterns="LayerNorm.weight",
             ),
             WeightRenaming(
-                source_keys="LayerNorm.beta",
-                target_keys="LayerNorm.bias",
+                source_patterns="LayerNorm.beta",
+                target_patterns="LayerNorm.bias",
             ),
         ],
     }
     if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
         mapping["legacy"] += [
             WeightRenaming(
-                source_keys="weight_g",
-                target_keys="parametrizations.weight.original0",
+                source_patterns="weight_g",
+                target_patterns="parametrizations.weight.original0",
             ),
             WeightRenaming(
-                source_keys="weight_v",
-                target_keys="parametrizations.weight.original1",
+                source_patterns="weight_v",
+                target_patterns="parametrizations.weight.original1",
             ),
         ]
     else:
         mapping["legacy"] += [
             WeightRenaming(
-                source_keys="parametrizations.weight.original0",
-                target_keys="weight_g",
+                source_patterns="parametrizations.weight.original0",
+                target_patterns="weight_g",
             ),
             WeightRenaming(
-                source_keys="parametrizations.weight.original1",
-                target_keys="weight_v",
+                source_patterns="parametrizations.weight.original1",
+                target_patterns="weight_v",
             ),
         ]
 
```
```diff
@@ -127,5 +143,72 @@ def _build_checkpoint_conversion_mapping():
 def get_checkpoint_conversion_mapping(model_type):
     global _checkpoint_conversion_mapping_cache
     _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
-    globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
-    return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type, None))
+    return deepcopy(_checkpoint_conversion_mapping_cache.get(model_type))
+
+
+# DO NOT MODIFY, KEPT FOR BC ONLY
+VLMS = [
+    "aria",
+    "ayavision",
+    "colpali",
+    "emu3",
+    "fuyu",
+    "gotocr2",
+    "gemma3",
+    "internvl",
+    "llava",  # all llava-prefixed models fall under this check
+    "mistral3",
+    "mllama",
+    "paligemma",
+    "shieldgemma2",
+    "qwen2vl",
+    "qwen2_5_vl",
+    "videollava",
+    "vipllava",
+    "sam3_video",
+    "sam3",
+    "sam3_tracker",
+    "sam3_tracker_video",
+]
+
+
+def get_model_conversion_mapping(
+    model: PreTrainedModel,
+    key_mapping: dict[str, str] | None = None,
+    hf_quantizer: HfQuantizer | None = None,
+    add_legacy: bool = True,
+) -> list[WeightConverter | WeightRenaming]:
+    """
+    For a given `model`, obtain the weight conversion mapping, if any is registered, either from the simple
+    renamings in the `_checkpoint_conversion_mapping` class attribute or from the general WeightConverter mapping.
+    """
+    weight_conversions = []
+
+    # Load models with key mapping
+    if key_mapping is not None:
+        weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]
+    elif any(
+        allowed_name in class_name.__name__.lower()
+        for class_name in model.__class__.__mro__[:-1]
+        for allowed_name in VLMS
+    ):
+        weight_conversions = [
+            WeightRenaming(source_patterns=k, target_patterns=v)
+            for k, v in model._checkpoint_conversion_mapping.items()
+        ]
+
+    # TODO: should be checked recursively on submodels!!
+    model_type = getattr(model.config, "model_type", None)
+    if model_type is not None:
+        model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
+        if model_specific_conversions is not None:
+            weight_conversions.extend(model_specific_conversions)
+
+    if add_legacy:
+        weight_conversions.extend(get_checkpoint_conversion_mapping("legacy"))
+
+    # Add the ones from the quantizer as well if provided
+    if hf_quantizer is not None:
+        weight_conversions.extend(hf_quantizer.get_weight_conversions())
+
+    return weight_conversions
```
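
A hedged sketch of how the new helper might be called; the checkpoint name and the `key_mapping` value are placeholders, not taken from the commit:

```python
from transformers import AutoModelForCausalLM
from transformers.conversion_mapping import get_model_conversion_mapping

# Any model works; a Mixtral checkpoint exercises the MoE conversions above.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
conversions = get_model_conversion_mapping(
    model,
    key_mapping=None,   # or e.g. {"^backbone": "model"} to inject ad-hoc renames
    hf_quantizer=None,  # a quantizer would contribute its own conversions
    add_legacy=True,    # include the LayerNorm gamma/beta and weight-norm renames
)
for conv in conversions:
    print(type(conv).__name__)  # WeightRenaming / WeightConverter entries
```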
