@@ -33,23 +33,28 @@ def test_argmax(x, data):
3333    )
3434    keepdims  =  kw .get ("keepdims" , False )
3535
36-     out  =  xp .argmax (x , ** kw )
36+     repro_snippet  =  ph .format_snippet (f"xp.argmax({ x !r} { kw  =  }  )
37+     try :
38+         out  =  xp .argmax (x , ** kw )
3739
38-     ph .assert_default_index ("argmax" , out .dtype )
39-     axes  =  sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
40-     ph .assert_keepdimable_shape (
41-         "argmax" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw 
42-     )
43-     scalar_type  =  dh .get_scalar_type (x .dtype )
44-     for  indices , out_idx  in  zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
45-         max_i  =  int (out [out_idx ])
46-         elements  =  []
47-         for  idx  in  indices :
48-             s  =  scalar_type (x [idx ])
49-             elements .append (s )
50-         expected  =  max (range (len (elements )), key = elements .__getitem__ )
51-         ph .assert_scalar_equals ("argmax" , type_ = int , idx = out_idx , out = max_i ,
52-                                 expected = expected , kw = kw )
40+         ph .assert_default_index ("argmax" , out .dtype )
41+         axes  =  sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
42+         ph .assert_keepdimable_shape (
43+             "argmax" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw 
44+         )
45+         scalar_type  =  dh .get_scalar_type (x .dtype )
46+         for  indices , out_idx  in  zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
47+             max_i  =  int (out [out_idx ])
48+             elements  =  []
49+             for  idx  in  indices :
50+                 s  =  scalar_type (x [idx ])
51+                 elements .append (s )
52+             expected  =  max (range (len (elements )), key = elements .__getitem__ )
53+             ph .assert_scalar_equals ("argmax" , type_ = int , idx = out_idx , out = max_i ,
54+                                     expected = expected , kw = kw )
55+     except  Exception  as  exc :
56+         exc .add_note (repro_snippet )
57+         raise 
5358
5459
5560@given ( 
@@ -70,22 +75,27 @@ def test_argmin(x, data):
7075    )
7176    keepdims  =  kw .get ("keepdims" , False )
7277
73-     out  =  xp .argmin (x , ** kw )
78+     repro_snippet  =  ph .format_snippet (f"xp.argmin({ x !r} { kw  =  }  )
79+     try :
80+         out  =  xp .argmin (x , ** kw )
7481
75-     ph .assert_default_index ("argmin" , out .dtype )
76-     axes  =  sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
77-     ph .assert_keepdimable_shape (
78-         "argmin" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw 
79-     )
80-     scalar_type  =  dh .get_scalar_type (x .dtype )
81-     for  indices , out_idx  in  zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
82-         min_i  =  int (out [out_idx ])
83-         elements  =  []
84-         for  idx  in  indices :
85-             s  =  scalar_type (x [idx ])
86-             elements .append (s )
87-         expected  =  min (range (len (elements )), key = elements .__getitem__ )
88-         ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
82+         ph .assert_default_index ("argmin" , out .dtype )
83+         axes  =  sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
84+         ph .assert_keepdimable_shape (
85+             "argmin" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw 
86+         )
87+         scalar_type  =  dh .get_scalar_type (x .dtype )
88+         for  indices , out_idx  in  zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
89+             min_i  =  int (out [out_idx ])
90+             elements  =  []
91+             for  idx  in  indices :
92+                 s  =  scalar_type (x [idx ])
93+                 elements .append (s )
94+             expected  =  min (range (len (elements )), key = elements .__getitem__ )
95+             ph .assert_scalar_equals ("argmin" , type_ = int , idx = out_idx , out = min_i , expected = expected )
96+     except  Exception  as  exc :
97+         exc .add_note (repro_snippet )
98+         raise 
8999
90100
91101# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on 
@@ -115,23 +125,28 @@ def test_count_nonzero(x, data):
115125
116126    assume (kw .get ("axis" , None ) !=  ())  # TODO clarify in the spec 
117127
118-     out  =  xp .count_nonzero (x , ** kw )
128+     repro_snippet  =  ph .format_snippet (f"xp.count_nonzero({ x !r} { kw  =  }  )
129+     try :
130+         out  =  xp .count_nonzero (x , ** kw )
119131
120-     ph .assert_default_index ("count_nonzero" , out .dtype )
121-     axes  =  sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
122-     ph .assert_keepdimable_shape (
123-         "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw 
124-     )
125-     scalar_type  =  dh .get_scalar_type (x .dtype )
132+          ph .assert_default_index ("count_nonzero" , out .dtype )
133+          axes  =  sh .normalize_axis (kw .get ("axis" , None ), x .ndim )
134+          ph .assert_keepdimable_shape (
135+              "count_nonzero" , in_shape = x .shape , out_shape = out .shape , axes = axes , keepdims = keepdims , kw = kw 
136+          )
137+          scalar_type  =  dh .get_scalar_type (x .dtype )
126138
127-     for  indices , out_idx  in  zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
128-         count  =  int (out [out_idx ])
129-         elements  =  []
130-         for  idx  in  indices :
131-             s  =  scalar_type (x [idx ])
132-             elements .append (s )
133-         expected  =  sum (el  !=  0  for  el  in  elements )
134-         ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
139+         for  indices , out_idx  in  zip (sh .axes_ndindex (x .shape , axes ), sh .ndindex (out .shape )):
140+             count  =  int (out [out_idx ])
141+             elements  =  []
142+             for  idx  in  indices :
143+                 s  =  scalar_type (x [idx ])
144+                 elements .append (s )
145+             expected  =  sum (el  !=  0  for  el  in  elements )
146+             ph .assert_scalar_equals ("count_nonzero" , type_ = int , idx = out_idx , out = count , expected = expected )
147+     except  Exception  as  exc :
148+         exc .add_note (repro_snippet )
149+         raise 
135150
136151
137152@given (hh .arrays (dtype = hh .all_dtypes , shape = ())) 
@@ -143,39 +158,44 @@ def test_nonzero_zerodim_error(x):
143158@pytest .mark .data_dependent_shapes  
144159@given (hh .arrays (dtype = hh .all_dtypes , shape = hh .shapes (min_dims = 1 , min_side = 1 ))) 
145160def  test_nonzero (x ):
146-     out  =  xp .nonzero (x )
147-     assert  len (out ) ==  x .ndim , f"{ len (out )= } { x .ndim = }  
148-     out_size  =  math .prod (out [0 ].shape )
149-     for  i  in  range (len (out )):
150-         assert  out [i ].ndim  ==  1 , f"out[{ i } { x .ndim }  
151-         size_at  =  math .prod (out [i ].shape )
152-         assert  size_at  ==  out_size , (
153-             f"prod(out[{ i } { size_at }  
154-             f"but should be prod(out[0].shape)={ out_size }  
155-         )
156-         ph .assert_default_index ("nonzero" , out [i ].dtype , repr_name = f"out[{ i }  )
157-     indices  =  []
158-     if  x .dtype  ==  xp .bool :
159-         for  idx  in  sh .ndindex (x .shape ):
160-             if  x [idx ]:
161-                 indices .append (idx )
162-     else :
163-         for  idx  in  sh .ndindex (x .shape ):
164-             if  x [idx ] !=  0 :
165-                 indices .append (idx )
166-     if  x .ndim  ==  0 :
167-         assert  out_size  ==  len (
168-             indices 
169-         ), f"prod(out[0].shape)={ out_size } { len (indices )}  
170-     else :
171-         for  i  in  range (out_size ):
172-             idx  =  tuple (int (x [i ]) for  x  in  out )
173-             f_idx  =  f"Extrapolated index (x[{ i } { idx }  
174-             f_element  =  f"x[{ idx } { x [idx ]}  
175-             assert  idx  in  indices , f"{ f_idx } { f_element }  
176-             assert  (
177-                 idx  ==  indices [i ]
178-             ), f"{ f_idx } { indices .index (idx )}  
161+     repro_snippet  =  ph .format_snippet (f"xp.nonzero({ x !r}  )
162+     try :
163+         out  =  xp .nonzero (x )
164+         assert  len (out ) ==  x .ndim , f"{ len (out )= } { x .ndim = }  
165+         out_size  =  math .prod (out [0 ].shape )
166+         for  i  in  range (len (out )):
167+             assert  out [i ].ndim  ==  1 , f"out[{ i } { x .ndim }  
168+             size_at  =  math .prod (out [i ].shape )
169+             assert  size_at  ==  out_size , (
170+                 f"prod(out[{ i } { size_at }  
171+                 f"but should be prod(out[0].shape)={ out_size }  
172+             )
173+             ph .assert_default_index ("nonzero" , out [i ].dtype , repr_name = f"out[{ i }  )
174+         indices  =  []
175+         if  x .dtype  ==  xp .bool :
176+             for  idx  in  sh .ndindex (x .shape ):
177+                 if  x [idx ]:
178+                     indices .append (idx )
179+         else :
180+             for  idx  in  sh .ndindex (x .shape ):
181+                 if  x [idx ] !=  0 :
182+                     indices .append (idx )
183+         if  x .ndim  ==  0 :
184+             assert  out_size  ==  len (
185+                 indices 
186+             ), f"prod(out[0].shape)={ out_size } { len (indices )}  
187+         else :
188+             for  i  in  range (out_size ):
189+                 idx  =  tuple (int (x [i ]) for  x  in  out )
190+                 f_idx  =  f"Extrapolated index (x[{ i } { idx }  
191+                 f_element  =  f"x[{ idx } { x [idx ]}  
192+                 assert  idx  in  indices , f"{ f_idx } { f_element }  
193+                 assert  (
194+                     idx  ==  indices [i ]
195+                 ), f"{ f_idx } { indices .index (idx )}  
196+     except  Exception  as  exc :
197+         exc .add_note (repro_snippet )
198+         raise 
179199
180200
181201@given ( 
@@ -188,31 +208,36 @@ def test_where(shapes, dtypes, data):
188208    x1  =  data .draw (hh .arrays (dtype = dtypes [0 ], shape = shapes [1 ]), label = "x1" )
189209    x2  =  data .draw (hh .arrays (dtype = dtypes [1 ], shape = shapes [2 ]), label = "x2" )
190210
191-     out  =  xp .where (cond , x1 , x2 )
192- 
193-     shape  =  sh .broadcast_shapes (* shapes )
194-     ph .assert_shape ("where" , out_shape = out .shape , expected = shape )
195-     # TODO: generate indices without broadcasting arrays 
196-     _cond  =  xp .broadcast_to (cond , shape )
197-     _x1  =  xp .broadcast_to (x1 , shape )
198-     _x2  =  xp .broadcast_to (x2 , shape )
199-     for  idx  in  sh .ndindex (shape ):
200-         if  _cond [idx ]:
201-             ph .assert_0d_equals (
202-                 "where" ,
203-                 x_repr = f"_x1[{ idx }  ,
204-                 x_val = _x1 [idx ],
205-                 out_repr = f"out[{ idx }  ,
206-                 out_val = out [idx ]
207-             )
208-         else :
209-             ph .assert_0d_equals (
210-                 "where" ,
211-                 x_repr = f"_x2[{ idx }  ,
212-                 x_val = _x2 [idx ],
213-                 out_repr = f"out[{ idx }  ,
214-                 out_val = out [idx ]
215-             )
211+     repro_snippet  =  ph .format_snippet (f"xp.where({ cond !r} { x1 !r} { x2 !r}  )
212+     try :
213+         out  =  xp .where (cond , x1 , x2 )
214+ 
215+         shape  =  sh .broadcast_shapes (* shapes )
216+         ph .assert_shape ("where" , out_shape = out .shape , expected = shape )
217+         # TODO: generate indices without broadcasting arrays 
218+         _cond  =  xp .broadcast_to (cond , shape )
219+         _x1  =  xp .broadcast_to (x1 , shape )
220+         _x2  =  xp .broadcast_to (x2 , shape )
221+         for  idx  in  sh .ndindex (shape ):
222+             if  _cond [idx ]:
223+                 ph .assert_0d_equals (
224+                     "where" ,
225+                     x_repr = f"_x1[{ idx }  ,
226+                     x_val = _x1 [idx ],
227+                     out_repr = f"out[{ idx }  ,
228+                     out_val = out [idx ]
229+                 )
230+             else :
231+                 ph .assert_0d_equals (
232+                     "where" ,
233+                     x_repr = f"_x2[{ idx }  ,
234+                     x_val = _x2 [idx ],
235+                     out_repr = f"out[{ idx }  ,
236+                     out_val = out [idx ]
237+                 )
238+     except  Exception  as  exc :
239+         exc .add_note (repro_snippet )
240+         raise 
216241
217242
218243@pytest .mark .min_version ("2023.12" ) 
@@ -238,12 +263,17 @@ def test_searchsorted(data):
238263        label = "x2" ,
239264    )
240265
241-     out  =  xp .searchsorted (x1 , x2 , sorter = sorter )
266+     repro_snippet  =  ph .format_snippet (f"xp.searchsorted({ x1 !r} { x2 !r} { sorter !r}  )
267+     try :
268+         out  =  xp .searchsorted (x1 , x2 , sorter = sorter )
242269
243-     ph .assert_dtype (
244-         "searchsorted" ,
245-         in_dtype = [x1 .dtype , x2 .dtype ],
246-         out_dtype = out .dtype ,
247-         expected = xp .__array_namespace_info__ ().default_dtypes ()["indexing" ],
248-     )
249-     # TODO: shapes and values testing 
270+         ph .assert_dtype (
271+             "searchsorted" ,
272+             in_dtype = [x1 .dtype , x2 .dtype ],
273+             out_dtype = out .dtype ,
274+             expected = xp .__array_namespace_info__ ().default_dtypes ()["indexing" ],
275+         )
276+         # TODO: shapes and values testing 
277+     except  Exception  as  exc :
278+         exc .add_note (repro_snippet )
279+         raise 
0 commit comments