1010from typing import TYPE_CHECKING
1111
1212if TYPE_CHECKING :
13+ from types import ModuleType
1314 from typing import Optional , Union , Any
1415 from ._typing import Array , Device
1516
1819import inspect
1920import warnings
2021
21- def _is_jax_zero_gradient_array (x ) :
22+ def _is_jax_zero_gradient_array (x : object ) -> bool :
2223 """Return True if `x` is a zero-gradient array.
2324
2425 These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +33,8 @@ def _is_jax_zero_gradient_array(x):
3233
3334 return isinstance (x , np .ndarray ) and x .dtype == jax .float0
3435
35- def is_numpy_array (x ):
36+
37+ def is_numpy_array (x : object ) -> bool :
3638 """
3739 Return True if `x` is a NumPy array.
3840
@@ -63,7 +65,8 @@ def is_numpy_array(x):
6365 return (isinstance (x , (np .ndarray , np .generic ))
6466 and not _is_jax_zero_gradient_array (x ))
6567
66- def is_cupy_array (x ):
68+
69+ def is_cupy_array (x : object ) -> bool :
6770 """
6871 Return True if `x` is a CuPy array.
6972
@@ -93,7 +96,8 @@ def is_cupy_array(x):
9396 # TODO: Should we reject ndarray subclasses?
9497 return isinstance (x , cp .ndarray )
9598
96- def is_torch_array (x ):
99+
100+ def is_torch_array (x : object ) -> bool :
97101 """
98102 Return True if `x` is a PyTorch tensor.
99103
@@ -120,7 +124,8 @@ def is_torch_array(x):
120124 # TODO: Should we reject ndarray subclasses?
121125 return isinstance (x , torch .Tensor )
122126
123- def is_ndonnx_array (x ):
127+
128+ def is_ndonnx_array (x : object ) -> bool :
124129 """
125130 Return True if `x` is a ndonnx Array.
126131
@@ -147,7 +152,8 @@ def is_ndonnx_array(x):
147152
148153 return isinstance (x , ndx .Array )
149154
150- def is_dask_array (x ):
155+
156+ def is_dask_array (x : object ) -> bool :
151157 """
152158 Return True if `x` is a dask.array Array.
153159
@@ -174,7 +180,8 @@ def is_dask_array(x):
174180
175181 return isinstance (x , dask .array .Array )
176182
177- def is_jax_array (x ):
183+
184+ def is_jax_array (x : object ) -> bool :
178185 """
179186 Return True if `x` is a JAX array.
180187
@@ -202,6 +209,7 @@ def is_jax_array(x):
202209
203210 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204211
212+
205213def is_pydata_sparse_array (x ) -> bool :
206214 """
207215 Return True if `x` is an array from the `sparse` package.
@@ -231,7 +239,8 @@ def is_pydata_sparse_array(x) -> bool:
231239 # TODO: Account for other backends.
232240 return isinstance (x , sparse .SparseArray )
233241
234- def is_array_api_obj (x ):
242+
243+ def is_array_api_obj (x : object ) -> bool :
235244 """
236245 Return True if `x` is an array API compatible array object.
237246
@@ -254,11 +263,13 @@ def is_array_api_obj(x):
254263 or is_pydata_sparse_array (x ) \
255264 or hasattr (x , '__array_namespace__' )
256265
257- def _compat_module_name ():
266+
267+ def _compat_module_name () -> str :
258268 assert __name__ .endswith ('.common._helpers' )
259269 return __name__ .removesuffix ('.common._helpers' )
260270
261- def is_numpy_namespace (xp ) -> bool :
271+
272+ def is_numpy_namespace (xp : ModuleType ) -> bool :
262273 """
263274 Returns True if `xp` is a NumPy namespace.
264275
@@ -278,7 +289,8 @@ def is_numpy_namespace(xp) -> bool:
278289 """
279290 return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
280291
281- def is_cupy_namespace (xp ) -> bool :
292+
293+ def is_cupy_namespace (xp : ModuleType ) -> bool :
282294 """
283295 Returns True if `xp` is a CuPy namespace.
284296
@@ -298,7 +310,8 @@ def is_cupy_namespace(xp) -> bool:
298310 """
299311 return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
300312
301- def is_torch_namespace (xp ) -> bool :
313+
314+ def is_torch_namespace (xp : ModuleType ) -> bool :
302315 """
303316 Returns True if `xp` is a PyTorch namespace.
304317
@@ -319,7 +332,7 @@ def is_torch_namespace(xp) -> bool:
319332 return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320333
321334
322- def is_ndonnx_namespace (xp ) :
335+ def is_ndonnx_namespace (xp : ModuleType ) -> bool :
323336 """
324337 Returns True if `xp` is an NDONNX namespace.
325338
@@ -337,7 +350,8 @@ def is_ndonnx_namespace(xp):
337350 """
338351 return xp .__name__ == 'ndonnx'
339352
340- def is_dask_namespace (xp ):
353+
354+ def is_dask_namespace (xp : ModuleType ) -> bool :
341355 """
342356 Returns True if `xp` is a Dask namespace.
343357
@@ -357,7 +371,8 @@ def is_dask_namespace(xp):
357371 """
358372 return xp .__name__ in {'dask.array' , _compat_module_name () + '.dask.array' }
359373
360- def is_jax_namespace (xp ):
374+
375+ def is_jax_namespace (xp : ModuleType ) -> bool :
361376 """
362377 Returns True if `xp` is a JAX namespace.
363378
@@ -378,7 +393,8 @@ def is_jax_namespace(xp):
378393 """
379394 return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380395
381- def is_pydata_sparse_namespace (xp ):
396+
397+ def is_pydata_sparse_namespace (xp : ModuleType ) -> bool :
382398 """
383399 Returns True if `xp` is a pydata/sparse namespace.
384400
@@ -396,7 +412,8 @@ def is_pydata_sparse_namespace(xp):
396412 """
397413 return xp .__name__ == 'sparse'
398414
399- def is_array_api_strict_namespace (xp ):
415+
416+ def is_array_api_strict_namespace (xp : ModuleType ) -> bool :
400417 """
401418 Returns True if `xp` is an array-api-strict namespace.
402419
@@ -414,13 +431,15 @@ def is_array_api_strict_namespace(xp):
414431 """
415432 return xp .__name__ == 'array_api_strict'
416433
417- def _check_api_version (api_version ):
434+
435+ def _check_api_version (api_version : str ) -> None :
418436 if api_version in ['2021.12' , '2022.12' ]:
419437 warnings .warn (f"The { api_version } version of the array API specification was requested but the returned namespace is actually version 2023.12" )
420438 elif api_version is not None and api_version not in ['2021.12' , '2022.12' ,
421439 '2023.12' ]:
422440 raise ValueError ("Only the 2023.12 version of the array API specification is currently supported" )
423441
442+
424443def array_namespace (* xs , api_version = None , use_compat = None ):
425444 """
426445 Get the array API compatible namespace for the arrays `xs`.
@@ -808,9 +827,10 @@ def size(x: Array) -> int | None:
808827 return None if math .isnan (out ) else out
809828
810829
811- def is_writeable_array (x ) -> bool :
830+ def is_writeable_array (x : object ) -> bool :
812831 """
813832 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
833+ Return False if `x` is not an array API compatible object.
814834
815835 Warning
816836 -------
@@ -821,10 +841,10 @@ def is_writeable_array(x) -> bool:
821841 return x .flags .writeable
822842 if is_jax_array (x ) or is_pydata_sparse_array (x ):
823843 return False
824- return True
844+ return is_array_api_obj ( x )
825845
826846
827- def is_lazy_array (x ) -> bool :
847+ def is_lazy_array (x : object ) -> bool :
828848 """Return True if x is potentially a future or it may be otherwise impossible or
829849 expensive to eagerly read its contents, regardless of their size, e.g. by
830850 calling ``bool(x)`` or ``float(x)``.
@@ -857,6 +877,9 @@ def is_lazy_array(x) -> bool:
857877 if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
858878 return True
859879
880+ if not is_array_api_obj (x ):
881+ return False
882+
860883 # Unknown Array API compatible object. Note that this test may have dire consequences
861884 # in terms of performance, e.g. for a lazy object that eagerly computes the graph
862885 # on __bool__ (dask is one such example, which however is special-cased above).
0 commit comments