|
1 | 1 | from array_api_compat import ( # noqa: F401 |
2 | 2 | is_numpy_array, is_cupy_array, is_torch_array, |
3 | 3 | is_dask_array, is_jax_array, is_pydata_sparse_array, |
| 4 | + is_ndonnx_array, |
4 | 5 | is_numpy_namespace, is_cupy_namespace, is_torch_namespace, |
5 | 6 | is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, |
6 | | - is_array_api_strict_namespace, |
| 7 | + is_array_api_strict_namespace, is_ndonnx_namespace, |
7 | 8 | ) |
8 | 9 |
|
9 | 10 | from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device |
|
22 | 23 | 'dask.array': 'is_dask_array', |
23 | 24 | 'jax.numpy': 'is_jax_array', |
24 | 25 | 'sparse': 'is_pydata_sparse_array', |
| 26 | + 'ndonnx': 'is_ndonnx_array', |
25 | 27 | } |
26 | 28 |
|
27 | 29 | is_namespace_functions = { |
|
32 | 34 | 'jax.numpy': 'is_jax_namespace', |
33 | 35 | 'sparse': 'is_pydata_sparse_namespace', |
34 | 36 | 'array_api_strict': 'is_array_api_strict_namespace', |
| 37 | + 'ndonnx': 'is_ndonnx_namespace', |
35 | 38 | } |
36 | 39 |
|
37 | 40 |
|
@@ -135,26 +138,40 @@ def test_to_device_host(library): |
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
    """Round-trip an int32 array from ``source_library`` into ``target_library``
    via ``asarray`` and check the result's type and dtype.

    Known-broken pairs are marked xfail; pairs that are disallowed by design
    (cupy/sparse implicit conversion) are skipped outright.
    """

    def _xfail(reason: str) -> None:
        # Attach an xfail marker instead of calling pytest.xfail() so the rest
        # of the test still executes.
        # xref https://github.com/pandas-dev/pandas/issues/38902
        request.node.add_marker(pytest.mark.xfail(reason=reason))

    pair = (source_library, target_library)
    if pair == ("dask.array", "torch"):
        # TODO: remove xfail once
        # https://github.com/dask/dask/issues/8260 is resolved
        _xfail(reason="Bug in dask raising error on conversion")
    elif source_library == "ndonnx" and target_library not in (
        "array_api_strict",
        "ndonnx",
        "numpy",
    ):
        _xfail(reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
    elif pair == ("ndonnx", "numpy"):
        _xfail(reason="produces numpy array of ndonnx scalar arrays")
    elif pair == ("jax.numpy", "torch"):
        _xfail(reason="casts int to float")
    elif source_library == "cupy" and target_library != "cupy":
        # cupy explicitly disallows implicit conversions to CPU
        pytest.skip(reason="cupy does not support implicit conversion to CPU")
    elif source_library == "sparse" and target_library != "sparse":
        pytest.skip(reason="`sparse` does not allow implicit densification")

    src_lib = import_(source_library, wrapper=True)
    tgt_lib = import_(target_library, wrapper=True)
    is_tgt_type = globals()[is_array_functions[target_library]]

    a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
    b = tgt_lib.asarray(a)

    assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
    assert b.dtype == tgt_lib.int32

158 | 175 |
|
159 | 176 | @pytest.mark.parametrize("library", wrapped_libraries) |
160 | 177 | def test_asarray_copy(library): |
|
0 commit comments