1818import inspect
1919import warnings
2020
21- def _is_jax_zero_gradient_array (x ) :
21+ def _is_jax_zero_gradient_array (x : object ) -> bool :
2222 """Return True if `x` is a zero-gradient array.
2323
2424 These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
3232
3333 return isinstance (x , np .ndarray ) and x .dtype == jax .float0
3434
35- def is_numpy_array (x ):
35+
36+ def is_numpy_array (x : object ) -> bool :
3637 """
3738 Return True if `x` is a NumPy array.
3839
@@ -63,7 +64,8 @@ def is_numpy_array(x):
6364 return (isinstance (x , (np .ndarray , np .generic ))
6465 and not _is_jax_zero_gradient_array (x ))
6566
66- def is_cupy_array (x ):
67+
68+ def is_cupy_array (x : object ) -> bool :
6769 """
6870 Return True if `x` is a CuPy array.
6971
@@ -93,7 +95,8 @@ def is_cupy_array(x):
9395 # TODO: Should we reject ndarray subclasses?
9496 return isinstance (x , cp .ndarray )
9597
96- def is_torch_array (x ):
98+
99+ def is_torch_array (x : object ) -> bool :
97100 """
98101 Return True if `x` is a PyTorch tensor.
99102
@@ -120,7 +123,8 @@ def is_torch_array(x):
120123 # TODO: Should we reject ndarray subclasses?
121124 return isinstance (x , torch .Tensor )
122125
123- def is_ndonnx_array (x ):
126+
127+ def is_ndonnx_array (x : object ) -> bool :
124128 """
125129 Return True if `x` is a ndonnx Array.
126130
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147151
148152 return isinstance (x , ndx .Array )
149153
150- def is_dask_array (x ):
154+
155+ def is_dask_array (x : object ) -> bool :
151156 """
152157 Return True if `x` is a dask.array Array.
153158
@@ -174,7 +179,8 @@ def is_dask_array(x):
174179
175180 return isinstance (x , dask .array .Array )
176181
177- def is_jax_array (x ):
182+
183+ def is_jax_array (x : object ) -> bool :
178184 """
179185 Return True if `x` is a JAX array.
180186
@@ -202,6 +208,7 @@ def is_jax_array(x):
202208
203209 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204210
211+
205212def is_pydata_sparse_array (x ) -> bool :
206213 """
207214 Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231238 # TODO: Account for other backends.
232239 return isinstance (x , sparse .SparseArray )
233240
234- def is_array_api_obj (x ):
241+
242+ def is_array_api_obj (x : object ) -> bool :
235243 """
236244 Return True if `x` is an array API compatible array object.
237245
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254262 or is_pydata_sparse_array (x ) \
255263 or hasattr (x , '__array_namespace__' )
256264
def _compat_module_name() -> str:
    """Return the dotted name of the package that contains this module.

    The name is obtained from this module's ``__name__`` by dropping the
    trailing ``'.common._helpers'`` component.
    """
    helpers_suffix = '.common._helpers'
    # Internal invariant: this module must live at ``<package>.common._helpers``.
    assert __name__.endswith(helpers_suffix)
    return __name__.removesuffix(helpers_suffix)
260269
270+
261271def is_numpy_namespace (xp ) -> bool :
262272 """
263273 Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278288 """
279289 return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
280290
291+
281292def is_cupy_namespace (xp ) -> bool :
282293 """
283294 Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298309 """
299310 return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
300311
312+
301313def is_torch_namespace (xp ) -> bool :
302314 """
303315 Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319331 return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320332
321333
322- def is_ndonnx_namespace (xp ):
334+ def is_ndonnx_namespace (xp ) -> bool :
323335 """
324336 Returns True if `xp` is an NDONNX namespace.
325337
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337349 """
338350 return xp .__name__ == 'ndonnx'
339351
340- def is_dask_namespace (xp ):
352+
353+ def is_dask_namespace (xp ) -> bool :
341354 """
342355 Returns True if `xp` is a Dask namespace.
343356
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357370 """
358371 return xp .__name__ in {'dask.array' , _compat_module_name () + '.dask.array' }
359372
360- def is_jax_namespace (xp ):
373+
374+ def is_jax_namespace (xp ) -> bool :
361375 """
362376 Returns True if `xp` is a JAX namespace.
363377
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378392 """
379393 return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380394
381- def is_pydata_sparse_namespace (xp ):
395+
396+ def is_pydata_sparse_namespace (xp ) -> bool :
382397 """
383398 Returns True if `xp` is a pydata/sparse namespace.
384399
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396411 """
397412 return xp .__name__ == 'sparse'
398413
399- def is_array_api_strict_namespace (xp ):
414+
415+ def is_array_api_strict_namespace (xp ) -> bool :
400416 """
401417 Returns True if `xp` is an array-api-strict namespace.
402418
@@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
414430 """
415431 return xp .__name__ == 'array_api_strict'
416432
def _check_api_version(api_version: str | None) -> None:
    """Validate the array API specification version requested by a caller.

    Parameters
    ----------
    api_version : str or None
        The requested specification version, e.g. ``'2023.12'``.
        ``None`` means that no particular version was requested and is
        always accepted.

    Raises
    ------
    ValueError
        If `api_version` is neither ``None`` nor one of ``'2021.12'``,
        ``'2022.12'`` or ``'2023.12'``.

    Warns
    -----
    UserWarning
        If `api_version` is ``'2021.12'`` or ``'2022.12'``: these versions
        are accepted, but the namespace handed back is version 2023.12.
    """
    if api_version in ('2021.12', '2022.12'):
        # Older versions are tolerated, but the caller is told what they
        # actually get.
        warnings.warn(
            f"The {api_version} version of the array API specification was "
            "requested but the returned namespace is actually version 2023.12"
        )
    elif api_version is not None and api_version != '2023.12':
        raise ValueError(
            "Only the 2023.12 version of the array API specification is "
            "currently supported"
        )
423440
441+
424442def array_namespace (* xs , api_version = None , use_compat = None ):
425443 """
426444 Get the array API compatible namespace for the arrays `xs`.
@@ -631,13 +649,9 @@ def device(x: Array, /) -> Device:
631649 return "cpu"
632650 elif is_dask_array (x ):
633651 # Peek at the metadata of the dask array to determine type
634- try :
635- import numpy as np
636- if isinstance (x ._meta , np .ndarray ):
637- # Must be on CPU since backed by numpy
638- return "cpu"
639- except ImportError :
640- pass
652+ if is_numpy_array (x ._meta ):
653+ # Must be on CPU since backed by numpy
654+ return "cpu"
641655 return _DASK_DEVICE
642656 elif is_jax_array (x ):
643657 # JAX has .device() as a method, but it is being deprecated so that it
@@ -788,24 +802,30 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788802 return x .to_device (device , stream = stream )
789803
790804
def size(x: Array) -> int | None:
    """
    Return the total number of elements of x.

    This is equivalent to `x.size` according to the `standard
    <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.

    This helper is included because PyTorch defines `size` in an
    :external+torch:meth:`incompatible way <torch.Tensor.size>`.
    It also normalises dask.array's behaviour, which reports nan for unknown
    sizes, whereas the standard requires None.

    Returns
    -------
    int or None
        The product of the entries of ``x.shape``, or None when any entry is
        None or the product is NaN.
    """
    shape = x.shape
    # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
    if None in shape:
        return None
    n_elements = math.prod(shape)
    # dask.array.Array.shape can contain NaN, which propagates into the product
    if math.isnan(n_elements):
        return None
    return n_elements
804823
805824
806- def is_writeable_array (x ) -> bool :
825+ def is_writeable_array (x : object ) -> bool :
807826 """
808827 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
828+ Return False if `x` is not an array API compatible object.
809829
810830 Warning
811831 -------
@@ -816,7 +836,67 @@ def is_writeable_array(x) -> bool:
816836 return x .flags .writeable
817837 if is_jax_array (x ) or is_pydata_sparse_array (x ):
818838 return False
819- return True
839+ return is_array_api_obj (x )
840+
841+
def is_lazy_array(x: object) -> bool:
    """Return True if x is potentially a future or it may be otherwise impossible or
    expensive to eagerly read its contents, regardless of their size, e.g. by
    calling ``bool(x)`` or ``float(x)``.

    Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
    cheap as long as the array has the right dtype and size.

    Note
    ----
    This function errs on the side of caution for array types that may or may not be
    lazy, e.g. JAX arrays, by always returning True for them.
    """
    # Backends that are treated as eager; checked in this order.
    eager_checks = (
        is_numpy_array,
        is_cupy_array,
        is_torch_array,
        is_pydata_sparse_array,
    )
    if any(check(x) for check in eager_checks):
        return False

    # **JAX note:** while it is possible to determine if you're inside or outside
    # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
    # as we do below for unknown arrays, this is not recommended by JAX best practices.

    # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
    # This behaviour, while impossible to change without breaking backwards
    # compatibility, is highly detrimental to performance as the whole graph will end
    # up being computed multiple times.

    # Backends that are always reported as lazy.
    lazy_checks = (is_jax_array, is_dask_array, is_ndonnx_array)
    if any(check(x) for check in lazy_checks):
        return True

    # Anything that is not an array API object at all is not a lazy array.
    if not is_array_api_obj(x):
        return False

    # Unknown Array API compatible object. Note that this test may have dire
    # consequences in terms of performance, e.g. for a lazy object that eagerly
    # computes the graph on __bool__ (dask is one such example, which however is
    # special-cased above).

    n_elements = size(x)
    if n_elements is None:
        # Unknown size
        return True

    xp = array_namespace(x)
    # Select a single point of the array
    probe = xp.reshape(x, (-1,))[0] if n_elements > 1 else x
    # Cast to dtype=bool and deal with size 0 arrays
    probe = xp.any(probe)

    try:
        bool(probe)
    # The Array API standard dictates that __bool__ should raise TypeError if the
    # output cannot be defined.
    # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
    except Exception:
        return True
    return False
820900
821901
822902__all__ = [
@@ -840,6 +920,7 @@ def is_writeable_array(x) -> bool:
840920 "is_pydata_sparse_array" ,
841921 "is_pydata_sparse_namespace" ,
842922 "is_writeable_array" ,
923+ "is_lazy_array" ,
843924 "size" ,
844925 "to_device" ,
845926]
0 commit comments