File tree Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Expand file tree Collapse file tree 2 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -582,7 +582,10 @@ def pad(
582582 xp = array_namespace (x )
583583
584584 padded = xp .full (
585- tuple (x + 2 * pad_width for x in x .shape ), fill_value = value , dtype = x .dtype
585+ tuple (x + 2 * pad_width for x in x .shape ),
586+ fill_value = value ,
587+ dtype = x .dtype ,
588+ device = _compat .device (x ),
586589 )
587590 padded [(slice (pad_width , - pad_width , None ),) * x .ndim ] = x
588591 return padded
Original file line number Diff line number Diff line change @@ -408,3 +408,11 @@ def test_mode_not_implemented(self):
408408 a = xp .arange (3 )
409409 with pytest .raises (NotImplementedError , match = "Only `'constant'`" ):
410410 pad (a , 2 , mode = "edge" )
411+
412+ def test_device (self ):
413+ device = xp .Device ("device1" )
414+ a = xp .asarray (0.0 , device = device )
415+ assert pad (a , 2 ).device == device
416+
417+ def test_xp (self ):
418+ assert_array_equal (pad (xp .asarray (0 ), 1 , xp = xp ), xp .zeros (3 ))
You can’t perform that action at this time.
0 commit comments