@@ -48,9 +48,10 @@ def is_numpy_array(x):
4848 is_array_api_obj
4949 is_cupy_array
5050 is_torch_array
51+ is_ndonnx_array
5152 is_dask_array
5253 is_jax_array
53- is_pydata_sparse
54+ is_pydata_sparse_array
5455 """
5556 # Avoid importing NumPy if it isn't already
5657 if 'numpy' not in sys .modules :
@@ -78,11 +79,12 @@ def is_cupy_array(x):
7879 is_array_api_obj
7980 is_numpy_array
8081 is_torch_array
82+ is_ndonnx_array
8183 is_dask_array
8284 is_jax_array
83- is_pydata_sparse
85+ is_pydata_sparse_array
8486 """
85- # Avoid importing NumPy if it isn't already
87+ # Avoid importing CuPy if it isn't already
8688 if 'cupy' not in sys .modules :
8789 return False
8890
@@ -107,7 +109,7 @@ def is_torch_array(x):
107109 is_cupy_array
108110 is_dask_array
109111 is_jax_array
110- is_pydata_sparse
112+ is_pydata_sparse_array
111113 """
112114 # Avoid importing torch if it isn't already
113115 if 'torch' not in sys .modules :
@@ -118,6 +120,33 @@ def is_torch_array(x):
118120 # TODO: Should we reject ndarray subclasses?
119121 return isinstance (x , torch .Tensor )
120122
123+ def is_ndonnx_array (x ):
124+ """
125+ Return True if `x` is a ndonnx Array.
126+
127+ This function does not import ndonnx if it has not already been imported
128+ and is therefore cheap to use.
129+
130+ See Also
131+ --------
132+
133+ array_namespace
134+ is_array_api_obj
135+ is_numpy_array
136+ is_cupy_array
137+ is_ndonnx_array
138+ is_dask_array
139+ is_jax_array
140+ is_pydata_sparse_array
141+ """
142+ # Avoid importing torch if it isn't already
143+ if 'ndonnx' not in sys .modules :
144+ return False
145+
146+ import ndonnx as ndx
147+
148+ return isinstance (x , ndx .Array )
149+
121150def is_dask_array (x ):
122151 """
123152 Return True if `x` is a dask.array Array.
@@ -133,8 +162,9 @@ def is_dask_array(x):
133162 is_numpy_array
134163 is_cupy_array
135164 is_torch_array
165+ is_ndonnx_array
136166 is_jax_array
137- is_pydata_sparse
167+ is_pydata_sparse_array
138168 """
139169 # Avoid importing dask if it isn't already
140170 if 'dask.array' not in sys .modules :
@@ -160,8 +190,9 @@ def is_jax_array(x):
160190 is_numpy_array
161191 is_cupy_array
162192 is_torch_array
193+ is_ndonnx_array
163194 is_dask_array
164- is_pydata_sparse
195+ is_pydata_sparse_array
165196 """
166197 # Avoid importing jax if it isn't already
167198 if 'jax' not in sys .modules :
@@ -172,7 +203,7 @@ def is_jax_array(x):
172203 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
173204
174205
175- def is_pydata_sparse (x ) -> bool :
206+ def is_pydata_sparse_array (x ) -> bool :
176207 """
177208 Return True if `x` is an array from the `sparse` package.
178209
@@ -188,6 +219,7 @@ def is_pydata_sparse(x) -> bool:
188219 is_numpy_array
189220 is_cupy_array
190221 is_torch_array
222+ is_ndonnx_array
191223 is_dask_array
192224 is_jax_array
193225 """
@@ -211,6 +243,7 @@ def is_array_api_obj(x):
211243 is_numpy_array
212244 is_cupy_array
213245 is_torch_array
246+ is_ndonnx_array
214247 is_dask_array
215248 is_jax_array
216249 """
@@ -219,7 +252,7 @@ def is_array_api_obj(x):
219252 or is_torch_array (x ) \
220253 or is_dask_array (x ) \
221254 or is_jax_array (x ) \
222- or is_pydata_sparse (x ) \
255+ or is_pydata_sparse_array (x ) \
223256 or hasattr (x , '__array_namespace__' )
224257
225258def _check_api_version (api_version ):
@@ -288,7 +321,7 @@ def your_function(x, y):
288321 is_torch_array
289322 is_dask_array
290323 is_jax_array
291- is_pydata_sparse
324+ is_pydata_sparse_array
292325
293326 """
294327 if use_compat not in [None , True , False ]:
@@ -307,12 +340,9 @@ def your_function(x, y):
307340 elif use_compat is False :
308341 namespaces .add (np )
309342 else :
310- # numpy 2.0 has __array_namespace__ and is fully array API
343+ # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
311344 # compatible.
312- if hasattr (x , '__array_namespace__' ):
313- namespaces .add (x .__array_namespace__ (api_version = api_version ))
314- else :
315- namespaces .add (numpy_namespace )
345+ namespaces .add (numpy_namespace )
316346 elif is_cupy_array (x ):
317347 if _use_compat :
318348 _check_api_version (api_version )
@@ -344,11 +374,15 @@ def your_function(x, y):
344374 elif use_compat is False :
345375 import jax .numpy as jnp
346376 else :
347- # jax.experimental.array_api is already an array namespace. We do
348- # not have a wrapper submodule for it.
349- import jax .experimental .array_api as jnp
377+ # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
378+ # For older JAX versions, it is available via jax.experimental.array_api.
379+ import jax .numpy
380+ if hasattr (jax .numpy , "__array_api_version__" ):
381+ jnp = jax .numpy
382+ else :
383+ import jax .experimental .array_api as jnp
350384 namespaces .add (jnp )
351- elif is_pydata_sparse (x ):
385+ elif is_pydata_sparse_array (x ):
352386 if use_compat is True :
353387 _check_api_version (api_version )
354388 raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
@@ -451,7 +485,7 @@ def device(x: Array, /) -> Device:
451485 return x .device ()
452486 else :
453487 return x .device
454- elif is_pydata_sparse (x ):
488+ elif is_pydata_sparse_array (x ):
455489 # `sparse` will gain `.device`, so check for this first.
456490 x_device = getattr (x , 'device' , None )
457491 if x_device is not None :
@@ -580,10 +614,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
580614 return x
581615 raise ValueError (f"Unsupported device { device !r} " )
582616 elif is_jax_array (x ):
583- # This import adds to_device to x
584- import jax .experimental .array_api # noqa: F401
617+ if not hasattr (x , "__array_namespace__" ):
618+ # In JAX v0.4.31 and older, this import adds to_device method to x.
619+ import jax .experimental .array_api # noqa: F401
585620 return x .to_device (device , stream = stream )
586- elif is_pydata_sparse (x ) and device == _device (x ):
621+ elif is_pydata_sparse_array (x ) and device == _device (x ):
587622 # Perform trivial check to return the same array if
588623 # device is same instead of err-ing.
589624 return x
@@ -613,7 +648,8 @@ def size(x):
613648 "is_jax_array" ,
614649 "is_numpy_array" ,
615650 "is_torch_array" ,
616- "is_pydata_sparse" ,
651+ "is_ndonnx_array" ,
652+ "is_pydata_sparse_array" ,
617653 "size" ,
618654 "to_device" ,
619655]
0 commit comments