
Commit 078ff68

zucchini-nlp and vasqu authored
🚨 Move rotary_partial_emb to RopeParams and delete unnecessary code 🔪 (#42255)
* tmp
* batch push
* maybe better pop and break, and we'll have one theta per config in the rope dict
* update a few models?
* fix tests that are easy first
* don't overwrite if already present!!!
* partial rotary factor
* more fixes to the god of fixes
* setdefault
* fix copies
* Update src/transformers/modeling_rope_utils.py
  Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
* Update src/transformers/models/efficientloftr/configuration_efficientloftr.py
  Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
* attempt one
* update all models
* fix tests
* fix tests
* oops
* fix slow tests with nested rope models
* fix copies
* deal with circular import and move the mixin to base config class
* fix copies
* fix a few tests
* update the migration guide

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
1 parent 554fb40 commit 078ff68

File tree

148 files changed (+1509 / -2013 lines)


MIGRATION_GUIDE_V5.md

Lines changed: 1 addition & 1 deletion
@@ -328,7 +328,7 @@ model_4bit = AutoModelForCausalLM.from_pretrained(
 
 - Methods to init a nested config such as `from_xxx_config` are deleted. Configs can be init from the `__init__` method in the same way. See [#41314](https://github.com/huggingface/transformers/pull/41314).
 - It is no longer possible to load a config class from a URL file. Configs must be loaded from either a local path or a repo on the Hub. See [#42383](https://github.com/huggingface/transformers/pull/42383).
-- All parameters for configuring the model's rotary embedding are now stored under `config.rope_parameters`, including `rope_theta` and `rope_type`. The model's `config.rope_parameters` is a plain dictionary in most cases, and can also be a nested dict in special cases (e.g. Gemma3 and ModernBert) with a different rope parameterization for each layer type. See [#39847](https://github.com/huggingface/transformers/pull/39847)
+- All parameters for configuring the model's rotary embedding are now stored under `config.rope_parameters`, including `rope_theta` and `rope_type`. The model's `config.rope_parameters` is a plain dictionary in most cases, and can also be a nested dict in special cases (e.g. Gemma3 and ModernBert) with a different rope parameterization for each layer type. Trying to get `config.rope_theta` now raises an `AttributeError`. See [#39847](https://github.com/huggingface/transformers/pull/39847) and [#42255](https://github.com/huggingface/transformers/pull/42255)
 - Qwen-VL family configuration is in a nested format and trying to access keys directly will throw an error (e.g. `config.vocab_size`). Users are expected to access keys from their respective sub-configs (`config.text_config.vocab_size`).
 - Configurations of non-generative models (any model that doesn't call `model.generate()`) will no longer have a `generation_config` and `model.config.generation_config` will throw an attribute error.

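To make the migration item above concrete, here is a minimal before/after sketch. The model id is only an example; any RoPE-based checkpoint behaves the same way, and keys other than `rope_theta`/`rope_type` are not guaranteed by this excerpt.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")  # any RoPE-based model

# v4.x access pattern, removed by this commit (now raises AttributeError):
# theta = config.rope_theta

# v5 access pattern: everything lives in one dict
theta = config.rope_parameters["rope_theta"]
rope_type = config.rope_parameters["rope_type"]
```
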
src/transformers/configuration_utils.py

Lines changed: 9 additions & 1 deletion
@@ -26,6 +26,7 @@
 from . import __version__
 from .dynamic_module_utils import custom_object_save
 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
+from .modeling_rope_utils import RotaryEmbeddingConfigMixin
 from .utils import (
     CONFIG_NAME,
     PushToHubMixin,
@@ -49,7 +50,7 @@
 SpecificPreTrainedConfigType = TypeVar("SpecificPreTrainedConfigType", bound="PreTrainedConfig")
 
 
-class PreTrainedConfig(PushToHubMixin):
+class PreTrainedConfig(PushToHubMixin, RotaryEmbeddingConfigMixin):
     # no-format
     r"""
     Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
@@ -261,6 +262,13 @@ def __init__(
 
             dtype = getattr(torch, dtype)
 
+        # BC for rotary embeddings. We will pop out legacy keys from kwargs and rename to new format
+        if hasattr(self, "rope_parameters"):
+            ignore_keys_at_rope_validation = kwargs.pop("ignore_keys_at_rope_validation", None)
+            kwargs = self.convert_rope_params_to_dict(
+                ignore_keys_at_rope_validation=ignore_keys_at_rope_validation, **kwargs
+            )
+
         # Attributes common for all models
         self.return_dict = return_dict
         self.output_hidden_states = output_hidden_states

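The hunk above only shows the call site; `convert_rope_params_to_dict` itself lives in `modeling_rope_utils.py`, whose diff is not rendered below. As a rough standalone sketch of the kind of legacy-kwarg folding such a hook has to perform (an assumption for illustration, not the actual implementation):

```python
def fold_legacy_rope_kwargs(rope_parameters=None, default_theta=10000.0, **kwargs):
    """Illustrative only: merge legacy `rope_theta`/`rope_scaling` kwargs into one rope dict."""
    # legacy checkpoints stored scaling settings under `rope_scaling`
    params = dict(rope_parameters or kwargs.pop("rope_scaling", None) or {})
    params.setdefault("rope_type", "default")
    # don't overwrite a theta that is already present in the dict
    params.setdefault("rope_theta", kwargs.pop("rope_theta", default_theta))
    return params, kwargs


params, remaining = fold_legacy_rope_kwargs(rope_theta=500000.0, attention_bias=False)
# params == {"rope_type": "default", "rope_theta": 500000.0}
# remaining == {"attention_bias": False}
```
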
src/transformers/modeling_rope_utils.py

Lines changed: 316 additions & 338 deletions
Large diffs are not rendered by default.
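
Since that diff is hidden, the shapes below only illustrate what the migration guide entry above describes: a flat dict for most models and a nested per-layer-type dict for models such as Gemma3 or ModernBert. The nested keys are an assumption, not taken from this commit.

```python
# Flat form used by most models
rope_parameters = {
    "rope_type": "default",
    "rope_theta": 10000.0,
    "partial_rotary_factor": 0.5,  # now carried here instead of as a separate config attribute
}

# Nested form for models with a different RoPE parameterization per layer type
rope_parameters_nested = {
    "full_attention": {"rope_type": "default", "rope_theta": 1_000_000.0},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
}
```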

src/transformers/models/apertus/configuration_apertus.py

Lines changed: 3 additions & 9 deletions
@@ -22,7 +22,7 @@
 from typing import Optional
 
 from ...configuration_utils import PreTrainedConfig
-from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
+from ...modeling_rope_utils import RopeParameters
 
 
 class ApertusConfig(PreTrainedConfig):
@@ -99,6 +99,7 @@ class ApertusConfig(PreTrainedConfig):
 
     model_type = "apertus"
     keys_to_ignore_at_inference = ["past_key_values"]
+    default_theta = 12000000.0
     base_model_tp_plan = {
         "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
         "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
@@ -160,14 +161,7 @@ def __init__(
         self.use_cache = use_cache
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
-        rope_scaling = kwargs.pop("rope_scaling", None)
-        self.rope_parameters = rope_scaling or rope_parameters
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 12000000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
+        self.rope_parameters = rope_parameters
 
         super().__init__(
             pad_token_id=pad_token_id,

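The class-level `default_theta = 12000000.0` replaces the old `kwargs.get("rope_theta", 12000000.0)` fallback. Assuming the base-class hook picks it up when no theta is passed, the expected behavior looks roughly like this (a sketch, not a test from this commit):

```python
from transformers import ApertusConfig

config = ApertusConfig()
print(config.rope_parameters)
# expected to contain {"rope_type": "default", "rope_theta": 12000000.0, ...}

# A legacy kwarg should still win over the class default, for backward compatibility:
legacy = ApertusConfig(rope_theta=10000.0)
print(legacy.rope_parameters["rope_theta"])  # 10000.0
```
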
src/transformers/models/apertus/modular_apertus.py

Lines changed: 30 additions & 25 deletions
@@ -20,11 +20,11 @@
 from torch import nn
 
 from ...cache_utils import Cache
-from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
+from ...configuration_utils import PreTrainedConfig
+from ...modeling_rope_utils import RopeParameters
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
-from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -43,7 +43,7 @@
 logger = logging.get_logger(__name__)
 
 
-class ApertusConfig(LlamaConfig):
+class ApertusConfig(PreTrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ApertusModel`]. It is used to instantiate a Apertus
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -116,6 +116,8 @@ class ApertusConfig(LlamaConfig):
     ```"""
 
     model_type = "apertus"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    default_theta = 12000000.0
     base_model_tp_plan = {
         "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
         "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
@@ -124,6 +126,11 @@ class ApertusConfig(LlamaConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
@@ -154,35 +161,33 @@ def __init__(
         attention_dropout: Optional[float] = 0.0,
         **kwargs,
     ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.rope_parameters = rope_parameters
+
         super().__init__(
-            vocab_size=vocab_size,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            num_hidden_layers=num_hidden_layers,
-            num_attention_heads=num_attention_heads,
-            num_key_value_heads=num_key_value_heads,
-            hidden_act=hidden_act,
-            max_position_embeddings=max_position_embeddings,
-            initializer_range=initializer_range,
-            rms_norm_eps=rms_norm_eps,
-            use_cache=use_cache,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
-            rope_parameters=rope_parameters,
-            attention_bias=attention_bias,
-            attention_dropout=attention_dropout,
             **kwargs,
         )
-        del self.pretraining_tp
-        del self.mlp_bias
-        del self.head_dim
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 12000000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
 
 
 class ApertusMLP(NemotronMLP):

src/transformers/models/arcee/configuration_arcee.py

Lines changed: 2 additions & 9 deletions
@@ -22,7 +22,7 @@
 from typing import Optional
 
 from ...configuration_utils import PreTrainedConfig
-from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
+from ...modeling_rope_utils import RopeParameters
 
 
 class ArceeConfig(PreTrainedConfig):
@@ -163,14 +163,7 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
         self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
-        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
-        rope_scaling = kwargs.pop("rope_scaling", None)
-        self.rope_parameters = rope_scaling or rope_parameters
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 10000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
+        self.rope_parameters = rope_parameters
 
         super().__init__(
             pad_token_id=pad_token_id,

src/transformers/models/aria/configuration_aria.py

Lines changed: 2 additions & 9 deletions
@@ -21,7 +21,7 @@
 from typing import Optional
 
 from ...configuration_utils import PreTrainedConfig
-from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
+from ...modeling_rope_utils import RopeParameters
 from ..auto import CONFIG_MAPPING, AutoConfig
 
 
@@ -168,14 +168,7 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
         self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
-        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
-        rope_scaling = kwargs.pop("rope_scaling", None)
-        self.rope_parameters = rope_scaling or rope_parameters
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 10000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
+        self.rope_parameters = rope_parameters
 
         super().__init__(
             pad_token_id=pad_token_id,

src/transformers/models/bamba/configuration_bamba.py

Lines changed: 3 additions & 11 deletions
@@ -17,7 +17,7 @@
 from typing import Optional
 
 from ...configuration_utils import PreTrainedConfig
-from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
+from ...modeling_rope_utils import RopeParameters
 from ...utils import logging
 
 
@@ -171,16 +171,6 @@ def __init__(
         self.num_logits_to_keep = num_logits_to_keep
 
         self.attn_layer_indices = attn_layer_indices
-        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
-        self.partial_rotary_factor = 0.5
-        rope_scaling = kwargs.pop("rope_scaling", None)
-        self.rope_parameters = rope_scaling or rope_parameters
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 10000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
-
         mamba_intermediate = mamba_expand * hidden_size
 
         if mamba_intermediate % mamba_n_heads != 0:
@@ -203,6 +193,8 @@ def __init__(
         self.mamba_conv_bias = mamba_conv_bias
         self.mamba_proj_bias = mamba_proj_bias
         self.z_loss_coefficient = z_loss_coefficient
+        self.rope_parameters = rope_parameters
+        kwargs["partial_rotary_factor"] = 0.5  # hardcode for BC
 
         super().__init__(
             pad_token_id=pad_token_id,

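Bamba's hard-coded `partial_rotary_factor = 0.5` now travels through `kwargs` into the rope dict rather than being set as a config attribute. What the factor means for the embedding itself is unchanged; the snippet below is a short illustration of standard partial-RoPE arithmetic, not Bamba's actual modeling code.

```python
import torch

head_dim = 64
partial_rotary_factor = 0.5  # read from rope_parameters after this commit
rotary_dim = int(head_dim * partial_rotary_factor)  # only the first 32 dims get rotated

base = 10000.0  # i.e. rope_theta
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
print(inv_freq.shape)  # torch.Size([16]) -> one frequency per rotated pair of dims
```
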
src/transformers/models/bitnet/configuration_bitnet.py

Lines changed: 3 additions & 9 deletions
@@ -16,7 +16,7 @@
 from typing import Optional
 
 from ...configuration_utils import PreTrainedConfig
-from ...modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
+from ...modeling_rope_utils import RopeParameters
 from ...utils import logging
 
 
@@ -97,6 +97,7 @@ class BitNetConfig(PreTrainedConfig):
 
     model_type = "bitnet"
     keys_to_ignore_at_inference = ["past_key_values"]
+    default_theta = 500000.0
 
     def __init__(
         self,
@@ -138,14 +139,7 @@ def __init__(
         self.use_cache = use_cache
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
-        # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
-        rope_scaling = kwargs.pop("rope_scaling", None)
-        self.rope_parameters = rope_scaling or rope_parameters
-
-        # Validate the correctness of rotary position embeddings parameters
-        rope_theta = kwargs.get("rope_theta", 500000.0)
-        standardize_rope_params(self, rope_theta=rope_theta)
-        rope_config_validation(self)
+        self.rope_parameters = rope_parameters
 
         super().__init__(
             pad_token_id=pad_token_id,
