Commit 6587d77

authored by Liu, Kaixuan <kaixuan.liu@intel.com>
add rotary kernel support to Qwen3 model (#41147)
* add rotary kernel support to Qwen3 model
* delete unnecessary import
* adjust code
* adjust code
* put get rotary kernel to hub_kernels.py
* fix wrong import
* refine code and adjust related modular code
* fix modular mismatch bug
* update code, use lazy load kernels
* fix check modular conversion issue
* fix CI bug for qwen3-next
* fix CI issue
* delete unused code
* rename to `apply_rotary_transformers`
* adjust import `lazy_load_kernel` location
* Update modular-generated modeling files with lazy_load_kernel import location
* fix conflicts
* add more check
* use decorator to map kernels for functions
* small fix
* small adjustment
* update code
* fix LINT issue
* update code to adapt to new `use_kernel_func_from_hub` API in kernels
* do not consider check_modular first
* update
* fix
* add compatibility for old version `kernels`
* add rotary fn kernel to all models
* update modular part
* Revert "update modular part" (reverts commit b8b68c7)
* update code

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent 4abfd16 commit 6587d77
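
In short, this commit adds a `use_kernel_func_from_hub` decorator (mirroring the existing `use_kernel_forward_from_hub`), applies it to each model's `apply_rotary_pos_emb`, and stores the function on the attention modules as `self.rotary_fn`, so the rotary application can be dispatched to the hub kernel `apply_rotary_transformers` from `kernels-community/rotary` on supported devices. A minimal sketch of that pattern, assuming torch is installed; `ToyAttention` is a hypothetical stand-in for the real attention classes, not code from this diff:

    # Sketch only: mirrors the decorator + self.rotary_fn pattern added in this commit.
    import torch
    import torch.nn as nn

    try:
        # Re-exported from transformers.integrations by this commit.
        from transformers.integrations import use_kernel_func_from_hub
    except ImportError:
        # Fallback no-op decorator so the sketch runs without the new API.
        def use_kernel_func_from_hub(func_name):
            return lambda func: func

    def rotate_half(x):
        # Rotates half the hidden dims of the input (same helper as in the modeling files).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    @use_kernel_func_from_hub("rotary_pos_emb")
    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        # Eager PyTorch path; on supported devices the decorator may route this call
        # to a hub kernel (e.g. apply_rotary_transformers from kernels-community/rotary).
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

    class ToyAttention(nn.Module):
        # Hypothetical module: keeps the function as an attribute, as each attention
        # __init__ now does with `self.rotary_fn = apply_rotary_pos_emb`.
        def __init__(self):
            super().__init__()
            self.rotary_fn = apply_rotary_pos_emb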

File tree

72 files changed: +369 -43 lines

Note: this is a large commit, so only a subset of the changed files is shown below.

src/transformers/integrations/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -72,6 +72,7 @@
         "register_kernel_mapping",
         "replace_kernel_forward_from_hub",
         "use_kernel_forward_from_hub",
+        "use_kernel_func_from_hub",
     ],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",
@@ -212,6 +213,7 @@
         register_kernel_mapping,
         replace_kernel_forward_from_hub,
         use_kernel_forward_from_hub,
+        use_kernel_func_from_hub,
     )
     from .integration_utils import (
         INTEGRATION_TO_CALLBACK,

src/transformers/integrations/hub_kernels.py

Lines changed: 54 additions & 2 deletions

@@ -32,22 +32,52 @@
     register_kernel_mapping,
     replace_kernel_forward_from_hub,
 )
+from kernels import (
+    use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
+)
+
+# Try to import FuncRepository, fallback if not available
+try:
+    from kernels import FuncRepository
+except ImportError:
+    FuncRepository = None
+
+# Try to import use_kernel_func_from_hub, fallback if not available
+try:
+    from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
+
+    _has_use_kernel_func_from_hub = True
+except ImportError:
+    _has_use_kernel_func_from_hub = False

 _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
 _kernels_available = True
 _kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES

 def use_kernel_forward_from_hub(layer_name: str):
     if _kernels_enabled:
-        from kernels import use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub
-
         return _kernels_use_kernel_forward_from_hub(layer_name)
     else:
         logger.warning_once(
             f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
         )
         return lambda cls: cls

+def use_kernel_func_from_hub(func_name: str):
+    if _kernels_enabled and _has_use_kernel_func_from_hub:
+        return _kernels_use_kernel_func_from_hub(func_name)
+    else:
+        if not _has_use_kernel_func_from_hub:
+            logger.warning_once(
+                "use_kernel_func_from_hub is not available in the installed kernels version. "
+                "Please upgrade kernels to use this feature."
+            )
+        else:
+            logger.warning_once(
+                f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
+            )
+        return lambda func: func
+
 _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository]] = {
     "MultiScaleDeformableAttention": {
         "cuda": LayerRepository(
@@ -162,6 +192,16 @@ def use_kernel_forward_from_hub(layer_name: str):
     },
 }

+# Add function kernel mappings if FuncRepository is available
+if FuncRepository is not None:
+    _KERNEL_MAPPING["rotary_pos_emb"] = {
+        "xpu": {
+            Mode.INFERENCE: FuncRepository(
+                repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
+            )
+        }
+    }
+
 def has_key(d, key):
     return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values())

@@ -187,6 +227,12 @@ def decorator(cls):

     return decorator

+def use_kernel_func_from_hub(*args, **kwargs):
+    def decorator(func):
+        return func
+
+    return decorator
+
 class LayerRepository:
     def __init__(self, *args, **kwargs):
         raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
@@ -199,6 +245,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
 def register_kernel_mapping(*args, **kwargs):
     raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

+def register_kernel_mapping_transformers(*args, **kwargs):
+    raise RuntimeError(
+        "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
+    )
+

 _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
     "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d"},
@@ -321,6 +372,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
 __all__ = [
     "LayerRepository",
     "use_kernel_forward_from_hub",
+    "use_kernel_func_from_hub",
     "register_kernel_mapping",
     "register_kernel_mapping_transformers",
     "replace_kernel_forward_from_hub",

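One point worth noting in the diff above: the decorator only dispatches when the `USE_HUB_KERNELS` toggle allows it and the installed `kernels` release actually exposes `use_kernel_func_from_hub`; in every other case it degrades to a pass-through and logs a warning once. A stripped-down sketch of that same try/except compatibility pattern, standalone and using plain `logging` rather than the transformers logger:

    # Standalone sketch of the graceful-degradation pattern used in hub_kernels.py.
    import logging

    logger = logging.getLogger(__name__)

    try:
        from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
        _has_use_kernel_func_from_hub = True
    except ImportError:
        _has_use_kernel_func_from_hub = False

    def use_kernel_func_from_hub(func_name: str):
        if _has_use_kernel_func_from_hub:
            # Delegate to the kernels library, which handles the hub lookup.
            return _kernels_use_kernel_func_from_hub(func_name)
        logger.warning(
            "use_kernel_func_from_hub is not available in the installed kernels version; "
            "returning the original function unchanged."
        )
        return lambda func: func

Disabling the feature via the `USE_HUB_KERNELS` environment variable has the same practical effect: the decorator returns the function unchanged and the eager PyTorch implementation of `apply_rotary_pos_emb` is used.
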
src/transformers/models/apertus/modeling_apertus.py

Lines changed: 3 additions & 1 deletion

@@ -28,7 +28,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GenericForTokenClassification, GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -147,6 +147,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -237,6 +238,7 @@ def __init__(self, config: ApertusConfig, layer_idx: Optional[int] = None):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)
         self.k_norm = ApertusRMSNorm(self.head_dim, config.rms_norm_eps)

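The hunks above only add the attribute; the forward path is not shown in this excerpt. A hedged usage sketch of how the stored `self.rotary_fn` would typically be called, reusing `apply_rotary_pos_emb` and `ToyAttention` from the sketch near the top of this page (shapes and the random `cos`/`sin` placeholders are illustrative only, not the actual Apertus rotary embedding):

    # Hypothetical call site: exercising the stored rotary function with toy tensors.
    import torch

    batch, num_heads, seq_len, head_dim = 1, 8, 16, 64
    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)
    # cos/sin come from the model's rotary embedding in practice; random here for shape only.
    cos = torch.randn(batch, seq_len, head_dim)
    sin = torch.randn(batch, seq_len, head_dim)

    attn = ToyAttention()
    q_rot, k_rot = attn.rotary_fn(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts cos/sin over heads
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
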
src/transformers/models/arcee/modeling_arcee.py

Lines changed: 3 additions & 1 deletion

@@ -30,7 +30,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import (
     GenericForQuestionAnswering,
@@ -154,6 +154,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -244,6 +245,7 @@ def __init__(self, config: ArceeConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

src/transformers/models/aria/modeling_aria.py

Lines changed: 3 additions & 1 deletion

@@ -29,7 +29,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -378,6 +378,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -468,6 +469,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 1 addition & 0 deletions

@@ -370,6 +370,7 @@ def __init__(self, config: BambaConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 3 additions & 1 deletion

@@ -27,7 +27,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -85,6 +85,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -175,6 +176,7 @@ def __init__(self, config: BitNetConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb
         self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 0 deletions

@@ -247,6 +247,7 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
             # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads

src/transformers/models/csm/modeling_csm.py

Lines changed: 3 additions & 1 deletion

@@ -32,7 +32,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -206,6 +206,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -296,6 +297,7 @@ def __init__(self, config: CsmConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb

     def forward(
         self,

src/transformers/models/cwm/modeling_cwm.py

Lines changed: 3 additions & 1 deletion

@@ -28,7 +28,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
@@ -113,6 +113,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


+@use_kernel_func_from_hub("rotary_pos_emb")
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -195,6 +196,7 @@ def __init__(self, config: CwmConfig, layer_idx: int):
         self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+        self.rotary_fn = apply_rotary_pos_emb
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

     def forward(
