@@ -10,6 +10,7 @@
     is_dask_array, is_jax_array, is_pydata_sparse_array,
     is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
     is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
+    is_array_api_strict_namespace,
 )
 
 from array_api_compat import (
@@ -33,6 +34,7 @@
     'dask.array': 'is_dask_namespace',
     'jax.numpy': 'is_jax_namespace',
     'sparse': 'is_pydata_sparse_namespace',
+    'array_api_strict': 'is_array_api_strict_namespace',
 }
 
 
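Note on the two hunks above: they add `array_api_strict` to the imported namespace checkers and to the name-to-checker mapping, so the parametrized `is_*_namespace` tests now cover it as well. A minimal sketch of the check this enables, assuming `array_api_strict` is installed (illustrative, not part of the diff):

    import array_api_strict
    from array_api_compat import is_array_api_strict_namespace, is_numpy_namespace

    # Only the matching checker should accept the array_api_strict namespace.
    assert is_array_api_strict_namespace(array_api_strict)
    assert not is_numpy_namespace(array_api_strict)
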
@@ -74,7 +76,12 @@ def test_xp_is_array_generics(library):
         is_func = globals()[func]
         if is_func(x0):
             matches.append(library2)
-    assert matches in ([library], ["numpy"])
+
+    if library == "array_api_strict":
+        # There is no is_array_api_strict_array() function
+        assert matches == []
+    else:
+        assert matches in ([library], ["numpy"])
 
 
 @pytest.mark.parametrize("library", all_libraries)
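The new branch is needed because array_api_compat exposes library-specific `is_*_array()` helpers but no `is_array_api_strict_array()`, so an `array_api_strict` array is recognised by none of them and `matches` stays empty. A rough illustration of that behaviour, again assuming `array_api_strict` is installed (illustrative only):

    import array_api_strict
    from array_api_compat import is_numpy_array, is_array_api_obj

    x = array_api_strict.asarray(1)
    assert not is_numpy_array(x)  # no library-specific helper claims this array
    assert is_array_api_obj(x)    # it is still a standard-compliant array object
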
@@ -192,26 +199,33 @@ def test_to_device_host(library):
 @pytest.mark.parametrize("target_library", is_array_functions.keys())
 @pytest.mark.parametrize("source_library", is_array_functions.keys())
 def test_asarray_cross_library(source_library, target_library, request):
-    if source_library == "dask.array" and target_library == "torch":
+    def _xfail(reason: str) -> None:
         # Allow rest of test to execute instead of immediately xfailing
         # xref https://github.com/pandas-dev/pandas/issues/38902
+        request.node.add_marker(pytest.mark.xfail(reason=reason))
 
+    if source_library == "dask.array" and target_library == "torch":
         # TODO: remove xfail once
         # https://github.com/dask/dask/issues/8260 is resolved
-        request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
-    if source_library == "cupy" and target_library != "cupy":
+        _xfail(reason="Bug in dask raising error on conversion")
+    elif source_library == "jax.numpy" and target_library == "torch":
+        _xfail(reason="casts int to float")
+    elif source_library == "cupy" and target_library != "cupy":
         # cupy explicitly disallows implicit conversions to CPU
         pytest.skip(reason="cupy does not support implicit conversion to CPU")
     elif source_library == "sparse" and target_library != "sparse":
         pytest.skip(reason="`sparse` does not allow implicit densification")
+
     src_lib = import_(source_library, wrapper=True)
     tgt_lib = import_(target_library, wrapper=True)
     is_tgt_type = globals()[is_array_functions[target_library]]
 
-    a = src_lib.asarray([1, 2, 3])
+    a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
     b = tgt_lib.asarray(a)
 
     assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
+    assert b.dtype == tgt_lib.int32
+
 
 
 @pytest.mark.parametrize("library", wrapped_libraries)
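The rewrite above routes every expected failure through a local `_xfail()` helper so the test body keeps executing and is reported as XFAIL rather than being cut short, adds the jax.numpy-to-torch case (which casts int to float), and pins the input dtype so the new `b.dtype == tgt_lib.int32` assertion can catch silent casts. A condensed, self-contained sketch of that runtime-xfail pattern (the test name and failing assertion are placeholders, not part of the diff):

    import pytest

    def test_known_bad_conversion(request):
        # Attach the xfail marker at run time: the body still executes, a
        # failure is reported as XFAIL and an unexpected pass as XPASS.
        request.node.add_marker(
            pytest.mark.xfail(reason="known upstream conversion bug")
        )
        assert 1 + 1 == 3  # placeholder for the conversion that currently fails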