Skip to content

Commit 762f1e7

Browse files
committed
Draw end shape relative to base shape
1 parent b2115b7 commit 762f1e7

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

array_api_tests/test_linalg.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
data)
2020
from ndindex import iter_indices
2121

22+
import math
2223
import itertools
2324
from typing import Tuple
2425

@@ -610,8 +611,9 @@ def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
610611

611612
@composite
612613
def _x2_shapes(draw):
613-
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
614-
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
614+
base_shape = draw(stack_shapes)[1] + draw(x1).shape[-1:]
615+
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(base_shape), 1)))
616+
return base_shape + (end,)
615617

616618
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
617619
x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))

0 commit comments

Comments
 (0)