1111import math
1212import inspect
1313import warnings
14+ from functools import cache
1415from typing import Optional , Union , Any
1516
1617from ._typing import Array , Device , Namespace
1718
1819
19- def _is_jax_zero_gradient_array (x : object ) -> bool :
20+ @cache
21+ def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
22+ try :
23+ mod = sys .modules [modname ]
24+ except KeyError :
25+ return False
26+ parent_cls = getattr (mod , clsname )
27+ return issubclass (cls , parent_cls )
28+
29+
30+ def _is_jax_zero_gradient_array (x : Array ) -> bool :
2031 """Return True if `x` is a zero-gradient array.
2132
2233 These arrays are a design quirk of Jax that may one day be removed.
2334 See https://github.com/google/jax/issues/20620.
2435 """
25- if 'numpy' not in sys .modules or 'jax' not in sys .modules :
36+ # Fast exit
37+ try :
38+ dtype = x .dtype
39+ except AttributeError :
40+ return False
41+ if not _issubclass_fast (type (dtype ), "numpy.dtypes" , "VoidDType" ):
2642 return False
2743
28- import numpy as np
29- import jax
44+ if "jax" not in sys . modules :
45+ return False
3046
31- return isinstance (x , np .ndarray ) and x .dtype == jax .float0
47+ import jax
48+ # jax.float0 is a np.dtype([('float0', 'V')])
49+ return dtype == jax .float0
3250
3351
3452def is_numpy_array (x : object ) -> bool :
@@ -52,15 +70,12 @@ def is_numpy_array(x: object) -> bool:
5270 is_jax_array
5371 is_pydata_sparse_array
5472 """
55- # Avoid importing NumPy if it isn't already
56- if 'numpy' not in sys .modules :
57- return False
58-
59- import numpy as np
60-
6173 # TODO: Should we reject ndarray subclasses?
62- return (isinstance (x , (np .ndarray , np .generic ))
63- and not _is_jax_zero_gradient_array (x ))
74+ cls = type (x )
75+ return (
76+ _issubclass_fast (cls , "numpy" , "ndarray" )
77+ or _issubclass_fast (cls , "numpy" , "generic" )
78+ ) and not _is_jax_zero_gradient_array (x )
6479
6580
6681def is_cupy_array (x : object ) -> bool :
@@ -84,14 +99,7 @@ def is_cupy_array(x: object) -> bool:
8499 is_jax_array
85100 is_pydata_sparse_array
86101 """
87- # Avoid importing CuPy if it isn't already
88- if 'cupy' not in sys .modules :
89- return False
90-
91- import cupy as cp
92-
93- # TODO: Should we reject ndarray subclasses?
94- return isinstance (x , cp .ndarray )
102+ return _issubclass_fast (type (x ), "cupy" , "ndarray" )
95103
96104
97105def is_torch_array (x : object ) -> bool :
@@ -112,14 +120,7 @@ def is_torch_array(x: object) -> bool:
112120 is_jax_array
113121 is_pydata_sparse_array
114122 """
115- # Avoid importing torch if it isn't already
116- if 'torch' not in sys .modules :
117- return False
118-
119- import torch
120-
121- # TODO: Should we reject ndarray subclasses?
122- return isinstance (x , torch .Tensor )
123+ return _issubclass_fast (type (x ), "torch" , "Tensor" )
123124
124125
125126def is_ndonnx_array (x : object ) -> bool :
@@ -141,13 +142,7 @@ def is_ndonnx_array(x: object) -> bool:
141142 is_jax_array
142143 is_pydata_sparse_array
143144 """
144- # Avoid importing torch if it isn't already
145- if 'ndonnx' not in sys .modules :
146- return False
147-
148- import ndonnx as ndx
149-
150- return isinstance (x , ndx .Array )
145+ return _issubclass_fast (type (x ), "ndonnx" , "Array" )
151146
152147
153148def is_dask_array (x : object ) -> bool :
@@ -169,13 +164,7 @@ def is_dask_array(x: object) -> bool:
169164 is_jax_array
170165 is_pydata_sparse_array
171166 """
172- # Avoid importing dask if it isn't already
173- if 'dask.array' not in sys .modules :
174- return False
175-
176- import dask .array
177-
178- return isinstance (x , dask .array .Array )
167+ return _issubclass_fast (type (x ), "dask.array" , "Array" )
179168
180169
181170def is_jax_array (x : object ) -> bool :
@@ -198,13 +187,7 @@ def is_jax_array(x: object) -> bool:
198187 is_dask_array
199188 is_pydata_sparse_array
200189 """
201- # Avoid importing jax if it isn't already
202- if 'jax' not in sys .modules :
203- return False
204-
205- import jax
206-
207- return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
190+ return _issubclass_fast (type (x ), "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
208191
209192
210193def is_pydata_sparse_array (x ) -> bool :
@@ -227,14 +210,8 @@ def is_pydata_sparse_array(x) -> bool:
227210 is_dask_array
228211 is_jax_array
229212 """
230- # Avoid importing jax if it isn't already
231- if 'sparse' not in sys .modules :
232- return False
233-
234- import sparse
235-
236213 # TODO: Account for other backends.
237- return isinstance ( x , sparse . SparseArray )
214+ return _issubclass_fast ( type ( x ), " sparse" , " SparseArray" )
238215
239216
240217def is_array_api_obj (x : object ) -> bool :
@@ -252,20 +229,29 @@ def is_array_api_obj(x: object) -> bool:
252229 is_dask_array
253230 is_jax_array
254231 """
255- return is_numpy_array (x ) \
256- or is_cupy_array (x ) \
257- or is_torch_array (x ) \
258- or is_dask_array (x ) \
259- or is_jax_array (x ) \
260- or is_pydata_sparse_array (x ) \
261- or hasattr (x , '__array_namespace__' )
232+ if hasattr (x , '__array_namespace__' ):
233+ return True
234+
235+ cls = type (x )
236+ return (
237+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
238+ _issubclass_fast (cls , "numpy" , "ndarray" )
239+ or _issubclass_fast (cls , "numpy" , "generic" )
240+ or _issubclass_fast (cls , "cupy" , "ndarray" )
241+ or _issubclass_fast (cls , "torch" , "Tensor" )
242+ or _issubclass_fast (cls , "dask.array" , "Array" )
243+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
244+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
245+ or _issubclass_fast (cls , "jax" , "Array" )
246+ )
262247
263248
264249def _compat_module_name () -> str :
265250 assert __name__ .endswith ('.common._helpers' )
266251 return __name__ .removesuffix ('.common._helpers' )
267252
268253
254+ @cache
269255def is_numpy_namespace (xp : Namespace ) -> bool :
270256 """
271257 Returns True if `xp` is a NumPy namespace.
@@ -287,6 +273,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
287273 return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
288274
289275
276+ @cache
290277def is_cupy_namespace (xp : Namespace ) -> bool :
291278 """
292279 Returns True if `xp` is a CuPy namespace.
@@ -308,6 +295,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
308295 return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
309296
310297
298+ @cache
311299def is_torch_namespace (xp : Namespace ) -> bool :
312300 """
313301 Returns True if `xp` is a PyTorch namespace.
@@ -348,6 +336,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
348336 return xp .__name__ == 'ndonnx'
349337
350338
339+ @cache
351340def is_dask_namespace (xp : Namespace ) -> bool :
352341 """
353342 Returns True if `xp` is a Dask namespace.
@@ -952,4 +941,4 @@ def is_lazy_array(x: object) -> bool:
952941 "to_device" ,
953942]
954943
955- _all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
944+ _all_ignore = ['cache' , ' sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments