Skip to content

Commit e807ffe

Browse files
committed
ENH: add "repro snippets" to test_searching_functions.py
1 parent 8096c93 commit e807ffe

File tree

1 file changed

+142
-112
lines changed

1 file changed

+142
-112
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 142 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,28 @@ def test_argmax(x, data):
3333
)
3434
keepdims = kw.get("keepdims", False)
3535

36-
out = xp.argmax(x, **kw)
36+
repro_snippet = ph.format_snippet(f"xp.argmax({x!r}, **kw) with {kw = }")
37+
try:
38+
out = xp.argmax(x, **kw)
3739

38-
ph.assert_default_index("argmax", out.dtype)
39-
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
40-
ph.assert_keepdimable_shape(
41-
"argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
42-
)
43-
scalar_type = dh.get_scalar_type(x.dtype)
44-
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
45-
max_i = int(out[out_idx])
46-
elements = []
47-
for idx in indices:
48-
s = scalar_type(x[idx])
49-
elements.append(s)
50-
expected = max(range(len(elements)), key=elements.__getitem__)
51-
ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i,
52-
expected=expected, kw=kw)
40+
ph.assert_default_index("argmax", out.dtype)
41+
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
42+
ph.assert_keepdimable_shape(
43+
"argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
44+
)
45+
scalar_type = dh.get_scalar_type(x.dtype)
46+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
47+
max_i = int(out[out_idx])
48+
elements = []
49+
for idx in indices:
50+
s = scalar_type(x[idx])
51+
elements.append(s)
52+
expected = max(range(len(elements)), key=elements.__getitem__)
53+
ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i,
54+
expected=expected, kw=kw)
55+
except Exception as exc:
56+
exc.add_note(repro_snippet)
57+
raise
5358

5459

5560
@given(
@@ -70,22 +75,27 @@ def test_argmin(x, data):
7075
)
7176
keepdims = kw.get("keepdims", False)
7277

73-
out = xp.argmin(x, **kw)
78+
repro_snippet = ph.format_snippet(f"xp.argmin({x!r}, **kw) with {kw = }")
79+
try:
80+
out = xp.argmin(x, **kw)
7481

75-
ph.assert_default_index("argmin", out.dtype)
76-
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
77-
ph.assert_keepdimable_shape(
78-
"argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
79-
)
80-
scalar_type = dh.get_scalar_type(x.dtype)
81-
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
82-
min_i = int(out[out_idx])
83-
elements = []
84-
for idx in indices:
85-
s = scalar_type(x[idx])
86-
elements.append(s)
87-
expected = min(range(len(elements)), key=elements.__getitem__)
88-
ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
82+
ph.assert_default_index("argmin", out.dtype)
83+
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
84+
ph.assert_keepdimable_shape(
85+
"argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
86+
)
87+
scalar_type = dh.get_scalar_type(x.dtype)
88+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
89+
min_i = int(out[out_idx])
90+
elements = []
91+
for idx in indices:
92+
s = scalar_type(x[idx])
93+
elements.append(s)
94+
expected = min(range(len(elements)), key=elements.__getitem__)
95+
ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
96+
except Exception as exc:
97+
exc.add_note(repro_snippet)
98+
raise
8999

90100

91101
# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
@@ -115,23 +125,28 @@ def test_count_nonzero(x, data):
115125

116126
assume(kw.get("axis", None) != ()) # TODO clarify in the spec
117127

118-
out = xp.count_nonzero(x, **kw)
128+
repro_snippet = ph.format_snippet(f"xp.count_nonzero({x!r}, **kw) with {kw = }")
129+
try:
130+
out = xp.count_nonzero(x, **kw)
119131

120-
ph.assert_default_index("count_nonzero", out.dtype)
121-
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
122-
ph.assert_keepdimable_shape(
123-
"count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
124-
)
125-
scalar_type = dh.get_scalar_type(x.dtype)
132+
ph.assert_default_index("count_nonzero", out.dtype)
133+
axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
134+
ph.assert_keepdimable_shape(
135+
"count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
136+
)
137+
scalar_type = dh.get_scalar_type(x.dtype)
126138

127-
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
128-
count = int(out[out_idx])
129-
elements = []
130-
for idx in indices:
131-
s = scalar_type(x[idx])
132-
elements.append(s)
133-
expected = sum(el != 0 for el in elements)
134-
ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected)
139+
for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
140+
count = int(out[out_idx])
141+
elements = []
142+
for idx in indices:
143+
s = scalar_type(x[idx])
144+
elements.append(s)
145+
expected = sum(el != 0 for el in elements)
146+
ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected)
147+
except Exception as exc:
148+
exc.add_note(repro_snippet)
149+
raise
135150

136151

137152
@given(hh.arrays(dtype=hh.all_dtypes, shape=()))
@@ -143,39 +158,44 @@ def test_nonzero_zerodim_error(x):
143158
@pytest.mark.data_dependent_shapes
144159
@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, min_side=1)))
145160
def test_nonzero(x):
146-
out = xp.nonzero(x)
147-
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
148-
out_size = math.prod(out[0].shape)
149-
for i in range(len(out)):
150-
assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"
151-
size_at = math.prod(out[i].shape)
152-
assert size_at == out_size, (
153-
f"prod(out[{i}].shape)={size_at}, "
154-
f"but should be prod(out[0].shape)={out_size}"
155-
)
156-
ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
157-
indices = []
158-
if x.dtype == xp.bool:
159-
for idx in sh.ndindex(x.shape):
160-
if x[idx]:
161-
indices.append(idx)
162-
else:
163-
for idx in sh.ndindex(x.shape):
164-
if x[idx] != 0:
165-
indices.append(idx)
166-
if x.ndim == 0:
167-
assert out_size == len(
168-
indices
169-
), f"prod(out[0].shape)={out_size}, but should be {len(indices)}"
170-
else:
171-
for i in range(out_size):
172-
idx = tuple(int(x[i]) for x in out)
173-
f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}"
174-
f_element = f"x[{idx}]={x[idx]}"
175-
assert idx in indices, f"{f_idx} results in {f_element}, a zero element"
176-
assert (
177-
idx == indices[i]
178-
), f"{f_idx} is in the wrong position, should be {indices.index(idx)}"
161+
repro_snippet = ph.format_snippet(f"xp.nonzero({x!r})")
162+
try:
163+
out = xp.nonzero(x)
164+
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
165+
out_size = math.prod(out[0].shape)
166+
for i in range(len(out)):
167+
assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"
168+
size_at = math.prod(out[i].shape)
169+
assert size_at == out_size, (
170+
f"prod(out[{i}].shape)={size_at}, "
171+
f"but should be prod(out[0].shape)={out_size}"
172+
)
173+
ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
174+
indices = []
175+
if x.dtype == xp.bool:
176+
for idx in sh.ndindex(x.shape):
177+
if x[idx]:
178+
indices.append(idx)
179+
else:
180+
for idx in sh.ndindex(x.shape):
181+
if x[idx] != 0:
182+
indices.append(idx)
183+
if x.ndim == 0:
184+
assert out_size == len(
185+
indices
186+
), f"prod(out[0].shape)={out_size}, but should be {len(indices)}"
187+
else:
188+
for i in range(out_size):
189+
idx = tuple(int(x[i]) for x in out)
190+
f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}"
191+
f_element = f"x[{idx}]={x[idx]}"
192+
assert idx in indices, f"{f_idx} results in {f_element}, a zero element"
193+
assert (
194+
idx == indices[i]
195+
), f"{f_idx} is in the wrong position, should be {indices.index(idx)}"
196+
except Exception as exc:
197+
exc.add_note(repro_snippet)
198+
raise
179199

180200

181201
@given(
@@ -188,31 +208,36 @@ def test_where(shapes, dtypes, data):
188208
x1 = data.draw(hh.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1")
189209
x2 = data.draw(hh.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2")
190210

191-
out = xp.where(cond, x1, x2)
192-
193-
shape = sh.broadcast_shapes(*shapes)
194-
ph.assert_shape("where", out_shape=out.shape, expected=shape)
195-
# TODO: generate indices without broadcasting arrays
196-
_cond = xp.broadcast_to(cond, shape)
197-
_x1 = xp.broadcast_to(x1, shape)
198-
_x2 = xp.broadcast_to(x2, shape)
199-
for idx in sh.ndindex(shape):
200-
if _cond[idx]:
201-
ph.assert_0d_equals(
202-
"where",
203-
x_repr=f"_x1[{idx}]",
204-
x_val=_x1[idx],
205-
out_repr=f"out[{idx}]",
206-
out_val=out[idx]
207-
)
208-
else:
209-
ph.assert_0d_equals(
210-
"where",
211-
x_repr=f"_x2[{idx}]",
212-
x_val=_x2[idx],
213-
out_repr=f"out[{idx}]",
214-
out_val=out[idx]
215-
)
211+
repro_snippet = ph.format_snippet(f"xp.where({cond!r}, {x1!r}, {x2!r})")
212+
try:
213+
out = xp.where(cond, x1, x2)
214+
215+
shape = sh.broadcast_shapes(*shapes)
216+
ph.assert_shape("where", out_shape=out.shape, expected=shape)
217+
# TODO: generate indices without broadcasting arrays
218+
_cond = xp.broadcast_to(cond, shape)
219+
_x1 = xp.broadcast_to(x1, shape)
220+
_x2 = xp.broadcast_to(x2, shape)
221+
for idx in sh.ndindex(shape):
222+
if _cond[idx]:
223+
ph.assert_0d_equals(
224+
"where",
225+
x_repr=f"_x1[{idx}]",
226+
x_val=_x1[idx],
227+
out_repr=f"out[{idx}]",
228+
out_val=out[idx]
229+
)
230+
else:
231+
ph.assert_0d_equals(
232+
"where",
233+
x_repr=f"_x2[{idx}]",
234+
x_val=_x2[idx],
235+
out_repr=f"out[{idx}]",
236+
out_val=out[idx]
237+
)
238+
except Exception as exc:
239+
exc.add_note(repro_snippet)
240+
raise
216241

217242

218243
@pytest.mark.min_version("2023.12")
@@ -238,12 +263,17 @@ def test_searchsorted(data):
238263
label="x2",
239264
)
240265

241-
out = xp.searchsorted(x1, x2, sorter=sorter)
266+
repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})")
267+
try:
268+
out = xp.searchsorted(x1, x2, sorter=sorter)
242269

243-
ph.assert_dtype(
244-
"searchsorted",
245-
in_dtype=[x1.dtype, x2.dtype],
246-
out_dtype=out.dtype,
247-
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
248-
)
249-
# TODO: shapes and values testing
270+
ph.assert_dtype(
271+
"searchsorted",
272+
in_dtype=[x1.dtype, x2.dtype],
273+
out_dtype=out.dtype,
274+
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
275+
)
276+
# TODO: shapes and values testing
277+
except Exception as exc:
278+
exc.add_note(repro_snippet)
279+
raise

0 commit comments

Comments
 (0)