Skip to content

Commit c2e010e

Browse files
authored
Merge pull request #319 from asmeurer/reshape-fix
Fix test_reshape
2 parents ad81cf6 + bcfcdba commit c2e010e

File tree

2 files changed

+67
-20
lines changed

2 files changed

+67
-20
lines changed

array_api_tests/hypothesis_helpers.py

+62
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,68 @@ def shapes(**kw):
245245
lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
246246
)
247247

248+
def _factorize(n: int) -> List[int]:
249+
# Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
250+
factors = []
251+
while n % 2 == 0:
252+
factors.append(2)
253+
n //= 2
254+
255+
for i in range(3, int(math.sqrt(n)) + 1, 2):
256+
while n % i == 0:
257+
factors.append(i)
258+
n //= i
259+
260+
if n > 1: # n is a prime number greater than 2
261+
factors.append(n)
262+
263+
return factors
264+
265+
MAX_SIDE = MAX_ARRAY_SIZE // 64
266+
# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
267+
MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)
268+
269+
270+
@composite
271+
def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
272+
"""
273+
Generate shape tuples whose product equals the product of array_shape.
274+
"""
275+
shape = draw(arr_shape)
276+
277+
array_size = math.prod(shape)
278+
279+
n_dims = draw(ndims)
280+
281+
# Handle special cases
282+
if array_size == 0:
283+
# Generate a random tuple, and ensure at least one of the entries is 0
284+
result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims)))
285+
pos = draw(integers(0, n_dims - 1))
286+
result[pos] = 0
287+
return tuple(result)
288+
289+
if array_size == 1:
290+
return tuple(1 for _ in range(n_dims))
291+
292+
# Get prime factorization
293+
factors = _factorize(array_size)
294+
295+
# Distribute prime factors randomly
296+
result = [1] * n_dims
297+
for factor in factors:
298+
pos = draw(integers(0, n_dims - 1))
299+
result[pos] *= factor
300+
301+
assert math.prod(result) == array_size
302+
303+
# An element of the reshape tuple can be -1, which means it is a stand-in
304+
# for the remaining factors.
305+
if draw(booleans()):
306+
pos = draw(integers(0, n_dims - 1))
307+
result[pos] = -1
308+
309+
return tuple(result)
248310

249311
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
250312

array_api_tests/test_manipulation_functions.py

+5-20
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from . import xps
1515
from .typing import Array, Shape
1616

17-
MAX_SIDE = hh.MAX_ARRAY_SIZE // 64
18-
MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims
19-
2017

2118
def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
2219
key = "shape"
@@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data):
6663
shape_strat = hh.shapes()
6764
else:
6865
_axis = axis if axis >= 0 else len(base_shape) + axis
69-
shape_strat = st.integers(0, MAX_SIDE).map(
66+
shape_strat = st.integers(0, hh.MAX_SIDE).map(
7067
lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
7168
)
7269
arrays = []
@@ -348,26 +345,14 @@ def test_repeat(x, kw, data):
348345
kw=kw)
349346
start = end
350347

351-
@st.composite
352-
def reshape_shapes(draw, shape):
353-
size = 1 if len(shape) == 0 else math.prod(shape)
354-
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
355-
assume(all(side <= MAX_SIDE for side in rshape))
356-
if len(rshape) != 0 and size > 0 and draw(st.booleans()):
357-
index = draw(st.integers(0, len(rshape) - 1))
358-
rshape[index] = -1
359-
return tuple(rshape)
360-
348+
reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
361349

362350
@pytest.mark.unvectorized
363-
@pytest.mark.skip("flaky") # TODO: fix!
364351
@given(
365-
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)),
366-
data=st.data(),
352+
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
353+
shape=hh.reshape_shapes(reshape_shape),
367354
)
368-
def test_reshape(x, data):
369-
shape = data.draw(reshape_shapes(x.shape))
370-
355+
def test_reshape(x, shape):
371356
out = xp.reshape(x, shape)
372357

373358
ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)

0 commit comments

Comments
 (0)