|
5 | 5 | is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, |
6 | 6 | ) |
7 | 7 |
|
8 | | -from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device |
9 | | - |
| 8 | +from array_api_compat import ( |
| 9 | + device, is_array_api_obj, is_lazy_array, is_writeable_array, to_device |
| 10 | +) |
10 | 11 | from ._helpers import import_, wrapped_libraries, all_libraries |
11 | 12 |
|
12 | 13 | import pytest |
@@ -92,6 +93,70 @@ def test_is_writeable_array_numpy(): |
92 | 93 | assert not is_writeable_array(x) |
93 | 94 |
|
94 | 95 |
|
| 96 | +@pytest.mark.parametrize("library", all_libraries) |
| 97 | +def test_is_lazy_array(library): |
| 98 | + lib = import_(library) |
| 99 | + x = lib.asarray([1, 2, 3]) |
| 100 | + assert isinstance(is_lazy_array(x), bool) |
| 101 | + |
| 102 | + |
| 103 | +@pytest.mark.parametrize("array", [ |
| 104 | + [], [1, 2], 1, 0, float("nan"), [[1, 2], [3, 4]] |
| 105 | +]) |
| 106 | +def test_is_lazy_array_unknown(array, monkeypatch): |
| 107 | + """Test is_lazy_array() on an unknown Array API compliant object""" |
| 108 | + xp = import_("jax.numpy") |
| 109 | + import array_api_compat.common._helpers |
| 110 | + import jax |
| 111 | + |
| 112 | + x = xp.asarray(array) |
| 113 | + # Prevent is_lazy_array() from special-casing JAX |
| 114 | + monkeypatch.setattr( |
| 115 | + array_api_compat.common._helpers, |
| 116 | + "is_jax_array", |
| 117 | + lambda x: False, |
| 118 | + ) |
| 119 | + |
| 120 | + assert not is_lazy_array(x) # Eager JAX |
| 121 | + assert jax.jit(is_lazy_array)(x) # Jitted (lazy) JAX |
| 122 | + |
| 123 | + |
| 124 | +def test_is_lazy_array_unknown_dask(monkeypatch): |
| 125 | + """Test is_lazy_array() on an unknown Array API compliant object which |
| 126 | + - may or may not raise an arbitrary exception on bool() |
| 127 | + - may or may not have NaN in its shape |
| 128 | + """ |
| 129 | + da = import_("dask.array", wrapper=True) |
| 130 | + import array_api_compat.common._helpers |
| 131 | + |
| 132 | + x = da.arange(10) |
| 133 | + y = x[x > 5] |
| 134 | + assert np.isnan(y.size) |
| 135 | + |
| 136 | + def do_not_run(_): |
| 137 | + raise AssertionError("do_not_run") |
| 138 | + |
| 139 | + z = x.map_blocks(do_not_run, dtype=x.dtype) |
| 140 | + with pytest.raises(AssertionError, match="do_not_run"): |
| 141 | + z.compute() |
| 142 | + |
| 143 | + # Prevent is_lazy_array() from special-casing Dask |
| 144 | + monkeypatch.setattr( |
| 145 | + array_api_compat.common._helpers, |
| 146 | + "is_dask_array", |
| 147 | + lambda x: False, |
| 148 | + ) |
| 149 | + monkeypatch.setattr( |
| 150 | + array_api_compat.common._helpers, |
| 151 | + "array_namespace", |
| 152 | + lambda x: da, |
| 153 | + ) |
| 154 | + |
| 155 | + assert not is_lazy_array(x) # Eagerly computes on bool() |
| 156 | + assert is_lazy_array(y) # NaN size |
| 157 | + assert is_lazy_array(z) # bool() raises AssertionError |
| 158 | + |
| 159 | + |
95 | 160 | @pytest.mark.parametrize("library", all_libraries) |
96 | 161 | def test_device(library): |
97 | 162 | xp = import_(library, wrapper=True) |
@@ -149,6 +214,7 @@ def test_asarray_cross_library(source_library, target_library, request): |
149 | 214 |
|
150 | 215 | assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" |
151 | 216 |
|
| 217 | + |
152 | 218 | @pytest.mark.parametrize("library", wrapped_libraries) |
153 | 219 | def test_asarray_copy(library): |
154 | 220 | # Note, we have this test here because the test suite currently doesn't |
|
0 commit comments