Skip to content

Commit bcfcdba

Browse files
committed
Fix test_reshape
- Fix the input strategy to not generate arrays that are too large, which was causing an error from hypothesis. - Rewrite the reshape_shapes() strategy to generate reshape tuples directly by distributing the prime factors of the array size, rather than by using filtering.
1 parent 5a2036b commit bcfcdba

File tree

2 files changed

+65
-18
lines changed

2 files changed

+65
-18
lines changed

array_api_tests/hypothesis_helpers.py

+62
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,68 @@ def shapes(**kw):
236236
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
237237
)
238238

239+
def _factorize(n: int) -> List[int]:
240+
# Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
241+
factors = []
242+
while n % 2 == 0:
243+
factors.append(2)
244+
n //= 2
245+
246+
for i in range(3, int(math.sqrt(n)) + 1, 2):
247+
while n % i == 0:
248+
factors.append(i)
249+
n //= i
250+
251+
if n > 1: # n is a prime number greater than 2
252+
factors.append(n)
253+
254+
return factors
255+
256+
MAX_SIDE = MAX_ARRAY_SIZE // 64
257+
# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
258+
MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)
259+
260+
261+
@composite
262+
def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
263+
"""
264+
Generate shape tuples whose product equals the product of array_shape.
265+
"""
266+
shape = draw(arr_shape)
267+
268+
array_size = math.prod(shape)
269+
270+
n_dims = draw(ndims)
271+
272+
# Handle special cases
273+
if array_size == 0:
274+
# Generate a random tuple, and ensure at least one of the entries is 0
275+
result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims)))
276+
pos = draw(integers(0, n_dims - 1))
277+
result[pos] = 0
278+
return tuple(result)
279+
280+
if array_size == 1:
281+
return tuple(1 for _ in range(n_dims))
282+
283+
# Get prime factorization
284+
factors = _factorize(array_size)
285+
286+
# Distribute prime factors randomly
287+
result = [1] * n_dims
288+
for factor in factors:
289+
pos = draw(integers(0, n_dims - 1))
290+
result[pos] *= factor
291+
292+
assert math.prod(result) == array_size
293+
294+
# An element of the reshape tuple can be -1, which means it is a stand-in
295+
# for the remaining factors.
296+
if draw(booleans()):
297+
pos = draw(integers(0, n_dims - 1))
298+
result[pos] = -1
299+
300+
return tuple(result)
239301

240302
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
241303

array_api_tests/test_manipulation_functions.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from . import xps
1515
from .typing import Array, Shape
1616

17-
MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
18-
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims
19-
2017

2118
def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
2219
key = "shape"
@@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data):
6663
shape_strat = hh.shapes()
6764
else:
6865
_axis = axis if axis >= 0 else len(base_shape) + axis
69-
shape_strat = st.integers(0, MAX_SIDE).map(
66+
shape_strat = st.integers(0, hh.MAX_SIDE).map(
7067
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
7168
)
7269
arrays = []
@@ -348,24 +345,12 @@ def test_repeat(x, kw, data):
348345
kw=kw)
349346
start = end
350347

351-
@st.composite
352-
def reshape_shapes(draw, shapes):
353-
shape = draw(shapes)
354-
size = 1 if len(shape) == 0 else math.prod(shape)
355-
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
356-
assume(all(side <= MAX_SIDE for side in rshape))
357-
if len(rshape) != 0 and size > 0 and draw(st.booleans()):
358-
index = draw(st.integers(0, len(rshape) - 1))
359-
rshape[index] = -1
360-
return tuple(rshape)
361-
362-
363-
reshape_shape = st.shared(hh.shapes(max_side=MAX_SIDE), key="reshape_shape")
348+
reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
364349

365350
@pytest.mark.unvectorized
366351
@given(
367352
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
368-
shape=reshape_shapes(reshape_shape),
353+
shape=hh.reshape_shapes(reshape_shape),
369354
)
370355
def test_reshape(x, shape):
371356
out = xp.reshape(x, shape)

0 commit comments

Comments
 (0)