3232 register_kernel_mapping ,
3333 replace_kernel_forward_from_hub ,
3434 )
35+ from kernels import (
36+ use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub ,
37+ )
38+
39+ # Try to import FuncRepository, fallback if not available
40+ try :
41+ from kernels import FuncRepository
42+ except ImportError :
43+ FuncRepository = None
44+
45+ # Try to import use_kernel_func_from_hub, fallback if not available
46+ try :
47+ from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
48+
49+ _has_use_kernel_func_from_hub = True
50+ except ImportError :
51+ _has_use_kernel_func_from_hub = False
3552
# Hub-kernel usage is opt-out via the USE_HUB_KERNELS environment variable
# (default "YES"); the flag is normalized to upper case before the check.
_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
_kernels_available = True
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
3956
def use_kernel_forward_from_hub(layer_name: str):
    """Decorator factory that swaps a layer's forward for a hub kernel.

    When hub kernels are disabled through USE_HUB_KERNELS, warns once and
    returns an identity decorator so the class is left untouched.
    """
    if not _kernels_enabled:
        logger.warning_once(
            f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
        )
        # No-op decorator: hand the class back unchanged.
        return lambda cls: cls
    return _kernels_use_kernel_forward_from_hub(layer_name)
5065
def use_kernel_func_from_hub(func_name: str):
    """Decorator factory that swaps a function for its hub kernel.

    Delegates to the `kernels` package when hub kernels are enabled and the
    installed `kernels` version provides `use_kernel_func_from_hub`;
    otherwise warns once (naming the failing precondition) and returns an
    identity decorator.
    """
    if _kernels_enabled and _has_use_kernel_func_from_hub:
        return _kernels_use_kernel_func_from_hub(func_name)

    # Report which of the two preconditions failed.
    if not _has_use_kernel_func_from_hub:
        logger.warning_once(
            "use_kernel_func_from_hub is not available in the installed kernels version. "
            "Please upgrade kernels to use this feature."
        )
    else:
        logger.warning_once(
            f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
        )
    return lambda func: func
80+
5181 _KERNEL_MAPPING : dict [str , dict [Device | str , LayerRepository ]] = {
5282 "MultiScaleDeformableAttention" : {
5383 "cuda" : LayerRepository (
@@ -162,6 +192,16 @@ def use_kernel_forward_from_hub(layer_name: str):
162192 },
163193 }
164194
# Register the rotary positional-embedding function kernel only when the
# installed `kernels` package exposes FuncRepository (older versions do not).
if FuncRepository is not None:
    _KERNEL_MAPPING["rotary_pos_emb"] = {
        "xpu": {
            Mode.INFERENCE: FuncRepository(
                repo_id="kernels-community/rotary",
                func_name="apply_rotary_transformers",
            ),
        },
    }
204+
def has_key(d, key):
    """Return True if `key` appears in `d` or in any nested dict value."""
    if key in d:
        return True
    for value in d.values():
        if isinstance(value, dict) and has_key(value, key):
            return True
    return False
167207
@@ -187,6 +227,12 @@ def decorator(cls):
187227
188228 return decorator
189229
def use_kernel_func_from_hub(*args, **kwargs):
    """No-op fallback used when `kernels` is not installed.

    Accepts and ignores any arguments, returning a decorator that hands the
    decorated function back unchanged.
    """

    def _identity(func):
        return func

    return _identity
235+
class LayerRepository:
    """Stub that fails loudly when the `kernels` package is missing."""

    def __init__(self, *args, **kwargs):
        # Any attempt to construct a repository without `kernels` is an error.
        raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
@@ -199,6 +245,11 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
def register_kernel_mapping(*args, **kwargs):
    """Stub that fails loudly when the `kernels` package is missing."""
    raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
201247
def register_kernel_mapping_transformers(*args, **kwargs):
    """Stub that fails loudly when the `kernels` package is missing."""
    raise RuntimeError(
        "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
    )
252+
202253
203254_HUB_KERNEL_MAPPING : dict [str , dict [str , str ]] = {
204255 "causal-conv1d" : {"repo_id" : "kernels-community/causal-conv1d" },
@@ -321,6 +372,7 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _
321372__all__ = [
322373 "LayerRepository" ,
323374 "use_kernel_forward_from_hub" ,
375+ "use_kernel_func_from_hub" ,
324376 "register_kernel_mapping" ,
325377 "register_kernel_mapping_transformers" ,
326378 "replace_kernel_forward_from_hub" ,