1111import math
1212import inspect
1313import warnings
14+ from functools import cache
1415from typing import Optional , Union , Any
1516
1617from ._typing import Array , Device , Namespace
1718
1819
19- def _is_jax_zero_gradient_array (x : object ) -> bool :
20+ @cache
21+ def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
22+ try :
23+ mod = sys .modules [modname ]
24+ except KeyError :
25+ return False
26+ parent_cls = getattr (mod , clsname )
27+ return issubclass (cls , parent_cls )
28+
29+
30+ def _is_jax_zero_gradient_array (x : Array ) -> bool :
2031 """Return True if `x` is a zero-gradient array.
2132
2233 These arrays are a design quirk of Jax that may one day be removed.
2334 See https://github.com/google/jax/issues/20620.
2435 """
25- if 'numpy' not in sys .modules or 'jax' not in sys .modules :
36+ # Fast exit
37+ try :
38+ dtype = x .dtype
39+ except AttributeError :
40+ return False
41+ if not _issubclass_fast (type (dtype ), "numpy.dtypes" , "VoidDType" ):
2642 return False
2743
28- import numpy as np
29- import jax
44+ if "jax" not in sys . modules :
45+ return False
3046
31- return isinstance (x , np .ndarray ) and x .dtype == jax .float0
47+ import jax
48+ # jax.float0 is a np.dtype([('float0', 'V')])
49+ return dtype == jax .float0
3250
3351
3452def is_numpy_array (x : object ) -> bool :
@@ -52,15 +70,12 @@ def is_numpy_array(x: object) -> bool:
5270 is_jax_array
5371 is_pydata_sparse_array
5472 """
55- # Avoid importing NumPy if it isn't already
56- if 'numpy' not in sys .modules :
57- return False
58-
59- import numpy as np
60-
6173 # TODO: Should we reject ndarray subclasses?
62- return (isinstance (x , (np .ndarray , np .generic ))
63- and not _is_jax_zero_gradient_array (x ))
74+ cls = type (x )
75+ return (
76+ _issubclass_fast (cls , "numpy" , "ndarray" )
77+ or _issubclass_fast (cls , "numpy" , "generic" )
78+ ) and not _is_jax_zero_gradient_array (x )
6479
6580
6681def is_cupy_array (x : object ) -> bool :
@@ -84,14 +99,7 @@ def is_cupy_array(x: object) -> bool:
8499 is_jax_array
85100 is_pydata_sparse_array
86101 """
87- # Avoid importing CuPy if it isn't already
88- if 'cupy' not in sys .modules :
89- return False
90-
91- import cupy as cp
92-
93- # TODO: Should we reject ndarray subclasses?
94- return isinstance (x , cp .ndarray )
102+ return _issubclass_fast (type (x ), "cupy" , "ndarray" )
95103
96104
97105def is_torch_array (x : object ) -> bool :
@@ -112,14 +120,7 @@ def is_torch_array(x: object) -> bool:
112120 is_jax_array
113121 is_pydata_sparse_array
114122 """
115- # Avoid importing torch if it isn't already
116- if 'torch' not in sys .modules :
117- return False
118-
119- import torch
120-
121- # TODO: Should we reject ndarray subclasses?
122- return isinstance (x , torch .Tensor )
123+ return _issubclass_fast (type (x ), "torch" , "Tensor" )
123124
124125
125126def is_ndonnx_array (x : object ) -> bool :
@@ -141,13 +142,7 @@ def is_ndonnx_array(x: object) -> bool:
141142 is_jax_array
142143 is_pydata_sparse_array
143144 """
144- # Avoid importing torch if it isn't already
145- if 'ndonnx' not in sys .modules :
146- return False
147-
148- import ndonnx as ndx
149-
150- return isinstance (x , ndx .Array )
145+ return _issubclass_fast (type (x ), "ndonnx" , "Array" )
151146
152147
153148def is_dask_array (x : object ) -> bool :
@@ -169,13 +164,7 @@ def is_dask_array(x: object) -> bool:
169164 is_jax_array
170165 is_pydata_sparse_array
171166 """
172- # Avoid importing dask if it isn't already
173- if 'dask.array' not in sys .modules :
174- return False
175-
176- import dask .array
177-
178- return isinstance (x , dask .array .Array )
167+ return _issubclass_fast (type (x ), "dask.array" , "Array" )
179168
180169
181170def is_jax_array (x : object ) -> bool :
@@ -198,13 +187,7 @@ def is_jax_array(x: object) -> bool:
198187 is_dask_array
199188 is_pydata_sparse_array
200189 """
201- # Avoid importing jax if it isn't already
202- if 'jax' not in sys .modules :
203- return False
204-
205- import jax
206-
207- return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
190+ return _issubclass_fast (type (x ), "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
208191
209192
210193def is_pydata_sparse_array (x ) -> bool :
@@ -227,14 +210,8 @@ def is_pydata_sparse_array(x) -> bool:
227210 is_dask_array
228211 is_jax_array
229212 """
230- # Avoid importing jax if it isn't already
231- if 'sparse' not in sys .modules :
232- return False
233-
234- import sparse
235-
236213 # TODO: Account for other backends.
237- return isinstance ( x , sparse . SparseArray )
214+ return _issubclass_fast ( type ( x ), " sparse" , " SparseArray" )
238215
239216
240217def is_array_api_obj (x : object ) -> bool :
@@ -252,20 +229,30 @@ def is_array_api_obj(x: object) -> bool:
252229 is_dask_array
253230 is_jax_array
254231 """
255- return is_numpy_array (x ) \
256- or is_cupy_array (x ) \
257- or is_torch_array (x ) \
258- or is_dask_array (x ) \
259- or is_jax_array (x ) \
260- or is_pydata_sparse_array (x ) \
261- or hasattr (x , '__array_namespace__' )
232+ return hasattr (x , '__array_namespace__' ) or _is_array_api_cls (type (x ))
233+
234+
235+ @cache
236+ def _is_array_api_cls (cls : type ) -> bool :
237+ return (
238+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
239+ _issubclass_fast (cls , "numpy" , "ndarray" )
240+ or _issubclass_fast (cls , "numpy" , "generic" )
241+ or _issubclass_fast (cls , "cupy" , "ndarray" )
242+ or _issubclass_fast (cls , "torch" , "Tensor" )
243+ or _issubclass_fast (cls , "dask.array" , "Array" )
244+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
245+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
246+ or _issubclass_fast (cls , "jax" , "Array" )
247+ )
262248
263249
264250def _compat_module_name () -> str :
265251 assert __name__ .endswith ('.common._helpers' )
266252 return __name__ .removesuffix ('.common._helpers' )
267253
268254
255+ @cache
269256def is_numpy_namespace (xp : Namespace ) -> bool :
270257 """
271258 Returns True if `xp` is a NumPy namespace.
@@ -287,6 +274,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
287274 return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
288275
289276
277+ @cache
290278def is_cupy_namespace (xp : Namespace ) -> bool :
291279 """
292280 Returns True if `xp` is a CuPy namespace.
@@ -308,6 +296,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
308296 return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
309297
310298
299+ @cache
311300def is_torch_namespace (xp : Namespace ) -> bool :
312301 """
313302 Returns True if `xp` is a PyTorch namespace.
@@ -348,6 +337,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
348337 return xp .__name__ == 'ndonnx'
349338
350339
340+ @cache
351341def is_dask_namespace (xp : Namespace ) -> bool :
352342 """
353343 Returns True if `xp` is a Dask namespace.
@@ -952,4 +942,4 @@ def is_lazy_array(x: object) -> bool:
952942 "to_device" ,
953943]
954944
955- _all_ignore = ['sys' , 'math' , 'inspect' , 'warnings' ]
945+ _all_ignore = ['cache' , ' sys' , 'math' , 'inspect' , 'warnings' ]
0 commit comments