@@ -65,7 +65,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
6565
6666
6767@wraps (xps .arrays ) 
68- def  arrays (dtype , * args , elements = None , ** kwargs ) ->  SearchStrategy [Array ]:
68+ def  arrays_no_scalars (dtype , * args , elements = None , ** kwargs ) ->  SearchStrategy [Array ]:
6969    """xps.arrays() without the crazy large numbers.""" 
7070    if  isinstance (dtype , SearchStrategy ):
7171        return  dtype .flatmap (lambda  d : arrays (d , * args , elements = elements , ** kwargs ))
@@ -78,6 +78,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
7878    return  xps .arrays (dtype , * args , elements = elements , ** kwargs )
7979
8080
81+ def  _f (a , flag ):
82+     return  a [()] if  a .ndim == 0  and  flag  else  a 
83+ 
84+ 
85+ @wraps (xps .arrays ) 
86+ def  arrays (dtype , * args , elements = None , ** kwargs ) ->  SearchStrategy [Array ]:
87+     """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars. 
88+ 
89+     Is only relevant for numpy: on all other libraries, array[()] is no-op. 
90+     """ 
91+     return  builds (_f , arrays_no_scalars (dtype , * args , elements = elements , ** kwargs ), booleans ())
92+ 
93+ 
8194_dtype_categories  =  [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .real_float_dtypes , dh .complex_dtypes ]
8295_sorted_dtypes  =  [d  for  category  in  _dtype_categories  for  d  in  category ]
8396
0 commit comments