@@ -300,31 +300,53 @@ def test_permute_dims(x, axes):
300
300
def test_repeat (x , kw , data ):
301
301
shape = x .shape
302
302
axis = kw .get ("axis" , None )
303
- dim = math .prod (shape ) if axis is None else shape [axis ]
304
- repeat_strat = st .integers (1 , 4 )
303
+ size = math .prod (shape ) if axis is None else shape [axis ]
304
+ repeat_strat = st .integers (1 , 10 )
305
305
repeats = data .draw (repeat_strat
306
306
| hh .arrays (dtype = hh .int_dtypes , elements = repeat_strat ,
307
- shape = st .sampled_from ([(1 ,), (dim ,)])),
307
+ shape = st .sampled_from ([(1 ,), (size ,)])),
308
308
label = "repeats" )
309
309
if isinstance (repeats , int ):
310
- n_repitions = dim * repeats
310
+ n_repititions = size * repeats
311
311
else :
312
312
if repeats .shape == (1 ,):
313
- n_repitions = dim * repeats [0 ]
313
+ n_repititions = size * int ( repeats [0 ])
314
314
else :
315
- n_repitions = int (xp .sum (repeats ))
315
+ n_repititions = int (xp .sum (repeats ))
316
+
317
+ assume (n_repititions <= hh .SQRT_MAX_ARRAY_SIZE )
316
318
317
319
out = xp .repeat (x , repeats , ** kw )
318
320
ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
319
321
if axis is None :
320
- expected_shape = (n_repitions ,)
322
+ expected_shape = (n_repititions ,)
321
323
else :
322
324
expected_shape = list (shape )
323
- expected_shape [axis ] = n_repitions
325
+ expected_shape [axis ] = n_repititions
324
326
expected_shape = tuple (expected_shape )
325
327
ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
326
- # TODO: values testing
327
328
329
+ # Test values
330
+
331
+ if isinstance (repeats , int ):
332
+ repeats_array = xp .full (size , repeats , dtype = xp .int32 )
333
+ else :
334
+ repeats_array = repeats
335
+
336
+ if kw .get ("axis" ) is None :
337
+ x = xp .reshape (x , (- 1 ,))
338
+ axis = 0
339
+
340
+ for idx , in sh .iter_indices (x .shape , skip_axes = axis ):
341
+ x_slice = x [idx ]
342
+ out_slice = out [idx ]
343
+ start = 0
344
+ for i , count in enumerate (repeats_array ):
345
+ end = start + count
346
+ ph .assert_array_elements ("repeat" , out = out_slice [start :end ],
347
+ expected = xp .full ((count ,), x_slice [i ], dtype = x .dtype ),
348
+ kw = kw )
349
+ start = end
328
350
329
351
@st .composite
330
352
def reshape_shapes (draw , shape ):
0 commit comments