@@ -349,7 +349,8 @@ def test_repeat(x, kw, data):
349
349
start = end
350
350
351
351
@st .composite
352
- def reshape_shapes (draw , shape ):
352
+ def reshape_shapes (draw , shapes ):
353
+ shape = draw (shapes )
353
354
size = 1 if len (shape ) == 0 else math .prod (shape )
354
355
rshape = draw (st .lists (st .integers (0 )).filter (lambda s : math .prod (s ) == size ))
355
356
assume (all (side <= MAX_SIDE for side in rshape ))
@@ -359,15 +360,14 @@ def reshape_shapes(draw, shape):
359
360
return tuple (rshape )
360
361
361
362
363
+ reshape_shape = st .shared (hh .shapes (max_side = MAX_SIDE ), key = "reshape_shape" )
364
+
362
365
@pytest .mark .unvectorized
363
- @pytest .mark .skip ("flaky" ) # TODO: fix!
364
366
@given (
365
- x = hh .arrays (dtype = hh .all_dtypes , shape = hh . shapes ( max_side = MAX_SIDE ) ),
366
- data = st . data ( ),
367
+ x = hh .arrays (dtype = hh .all_dtypes , shape = reshape_shape ),
368
+ shape = reshape_shapes ( reshape_shape ),
367
369
)
368
- def test_reshape (x , data ):
369
- shape = data .draw (reshape_shapes (x .shape ))
370
-
370
+ def test_reshape (x , shape ):
371
371
out = xp .reshape (x , shape )
372
372
373
373
ph .assert_dtype ("reshape" , in_dtype = x .dtype , out_dtype = out .dtype )
0 commit comments