1919 data )
2020from ndindex import iter_indices
2121
22+ import math
2223import itertools
24+ from typing import Tuple
2325
2426from .array_helpers import assert_exactly_equal , asarray
2527from .hypothesis_helpers import (arrays , all_floating_dtypes , xps , shapes ,
2628 kwargs , matrix_shapes , square_matrix_shapes ,
27- symmetric_matrices ,
29+ symmetric_matrices , SearchStrategy ,
2830 positive_definite_matrices , MAX_ARRAY_SIZE ,
2931 invertible_matrices , two_mutual_arrays ,
3032 mutually_promotable_dtypes , one_d_shapes ,
3537from . import dtype_helpers as dh
3638from . import pytest_helpers as ph
3739from . import shape_helpers as sh
40+ from .typing import Array
3841
3942from . import _array_module
4043from . import _array_module as xp
@@ -589,7 +592,7 @@ def test_slogdet(x):
589592 # TODO: Test this when we have tests for floating-point values.
590593 # assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
591594
592- def solve_args ():
595+ def solve_args () -> Tuple [ SearchStrategy [ Array ], SearchStrategy [ Array ]] :
593596 """
594597 Strategy for the x1 and x2 arguments to test_solve()
595598
@@ -608,8 +611,9 @@ def solve_args():
608611
609612 @composite
610613 def _x2_shapes (draw ):
611- end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
612- return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
614+ base_shape = draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :]
615+ end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE // max (math .prod (base_shape ), 1 )))
616+ return base_shape + (end ,)
613617
614618 x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
615619 x2 = arrays (shape = x2_shapes , dtype = mutual_dtypes .map (lambda pair : pair [1 ]))
0 commit comments