Skip to content

Commit 5811393

Browse files
authored
Merge pull request #242 from honno/solve-buffer-size-fix
Fix `test_solve` generating arrays larger than our max array size
2 parents 8fb5e01 + 695c67e commit 5811393

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

array_api_tests/hypothesis_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
312312
@composite
313313
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
314314
# For now, just generate stacks of diagonal matrices.
315-
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
316315
stack_shape = draw(stack_shapes)
316+
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),)
317317
dtype = draw(dtypes)
318318
elements = one_of(
319319
from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),

array_api_tests/test_linalg.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
data)
2020
from ndindex import iter_indices
2121

22+
import math
2223
import itertools
24+
from typing import Tuple
2325

2426
from .array_helpers import assert_exactly_equal, asarray
2527
from .hypothesis_helpers import (arrays, all_floating_dtypes, xps, shapes,
2628
kwargs, matrix_shapes, square_matrix_shapes,
27-
symmetric_matrices,
29+
symmetric_matrices, SearchStrategy,
2830
positive_definite_matrices, MAX_ARRAY_SIZE,
2931
invertible_matrices, two_mutual_arrays,
3032
mutually_promotable_dtypes, one_d_shapes,
@@ -35,6 +37,7 @@
3537
from . import dtype_helpers as dh
3638
from . import pytest_helpers as ph
3739
from . import shape_helpers as sh
40+
from .typing import Array
3841

3942
from . import _array_module
4043
from . import _array_module as xp
@@ -589,7 +592,7 @@ def test_slogdet(x):
589592
# TODO: Test this when we have tests for floating-point values.
590593
# assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
591594

592-
def solve_args():
595+
def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
593596
"""
594597
Strategy for the x1 and x2 arguments to test_solve()
595598
@@ -608,8 +611,9 @@ def solve_args():
608611

609612
@composite
610613
def _x2_shapes(draw):
611-
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
612-
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
614+
base_shape = draw(stack_shapes)[1] + draw(x1).shape[-1:]
615+
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(base_shape), 1)))
616+
return base_shape + (end,)
613617

614618
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
615619
x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))

0 commit comments

Comments
 (0)