19
19
data )
20
20
from ndindex import iter_indices
21
21
22
+ import math
22
23
import itertools
24
+ from typing import Tuple
23
25
24
26
from .array_helpers import assert_exactly_equal , asarray
25
27
from .hypothesis_helpers import (arrays , all_floating_dtypes , xps , shapes ,
26
28
kwargs , matrix_shapes , square_matrix_shapes ,
27
- symmetric_matrices ,
29
+ symmetric_matrices , SearchStrategy ,
28
30
positive_definite_matrices , MAX_ARRAY_SIZE ,
29
31
invertible_matrices , two_mutual_arrays ,
30
32
mutually_promotable_dtypes , one_d_shapes ,
35
37
from . import dtype_helpers as dh
36
38
from . import pytest_helpers as ph
37
39
from . import shape_helpers as sh
40
+ from .typing import Array
38
41
39
42
from . import _array_module
40
43
from . import _array_module as xp
@@ -589,7 +592,7 @@ def test_slogdet(x):
589
592
# TODO: Test this when we have tests for floating-point values.
590
593
# assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
591
594
592
- def solve_args ():
595
+ def solve_args () -> Tuple [ SearchStrategy [ Array ], SearchStrategy [ Array ]] :
593
596
"""
594
597
Strategy for the x1 and x2 arguments to test_solve()
595
598
@@ -608,8 +611,9 @@ def solve_args():
608
611
609
612
@composite
610
613
def _x2_shapes (draw ):
611
- end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
612
- return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
614
+ base_shape = draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :]
615
+ end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE // max (math .prod (base_shape ), 1 )))
616
+ return base_shape + (end ,)
613
617
614
618
x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
615
619
x2 = arrays (shape = x2_shapes , dtype = mutual_dtypes .map (lambda pair : pair [1 ]))
0 commit comments