We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent eb7e95e commit 9c0ea9cCopy full SHA for 9c0ea9c
array_api_compat/common/_helpers.py
@@ -229,10 +229,11 @@ def is_array_api_obj(x: object) -> bool:
229
is_dask_array
230
is_jax_array
231
"""
232
- if hasattr(x, '__array_namespace__'):
233
- return True
+ return hasattr(x, '__array_namespace__') or _is_array_api_cls(type(x))
234
235
- cls = type(x)
+
+@cache
236
+def _is_array_api_cls(cls: type) -> bool:
237
return (
238
# TODO: drop support for numpy<2 which didn't have __array_namespace__
239
_issubclass_fast(cls, "numpy", "ndarray")
0 commit comments