Skip to content

Commit 5a2036b

Browse files
committed
Use shared() in test_reshape
1 parent 4bbe6be commit 5a2036b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

Diff for: array_api_tests/test_manipulation_functions.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ def test_repeat(x, kw, data):
349349
start = end
350350

351351
@st.composite
352-
def reshape_shapes(draw, shape):
352+
def reshape_shapes(draw, shapes):
353+
shape = draw(shapes)
353354
size = 1 if len(shape) == 0 else math.prod(shape)
354355
rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
355356
assume(all(side <= MAX_SIDE for side in rshape))
@@ -359,15 +360,14 @@ def reshape_shapes(draw, shape):
359360
return tuple(rshape)
360361

361362

363+
reshape_shape = st.shared(hh.shapes(max_side=MAX_SIDE), key="reshape_shape")
364+
362365
@pytest.mark.unvectorized
363-
@pytest.mark.skip("flaky") # TODO: fix!
364366
@given(
365-
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)),
366-
data=st.data(),
367+
x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
368+
shape=reshape_shapes(reshape_shape),
367369
)
368-
def test_reshape(x, data):
369-
shape = data.draw(reshape_shapes(x.shape))
370-
370+
def test_reshape(x, shape):
371371
out = xp.reshape(x, shape)
372372

373373
ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)

0 commit comments

Comments
 (0)