File tree 2 files changed +12
-1
lines changed
2 files changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -582,7 +582,10 @@ def pad(
582
582
xp = array_namespace (x )
583
583
584
584
padded = xp .full (
585
- tuple (x + 2 * pad_width for x in x .shape ), fill_value = value , dtype = x .dtype
585
+ tuple (x + 2 * pad_width for x in x .shape ),
586
+ fill_value = value ,
587
+ dtype = x .dtype ,
588
+ device = _compat .device (x ),
586
589
)
587
590
padded [(slice (pad_width , - pad_width , None ),) * x .ndim ] = x
588
591
return padded
Original file line number Diff line number Diff line change @@ -408,3 +408,11 @@ def test_mode_not_implemented(self):
408
408
a = xp .arange (3 )
409
409
with pytest .raises (NotImplementedError , match = "Only `'constant'`" ):
410
410
pad (a , 2 , mode = "edge" )
411
+
412
+ def test_device (self ):
413
+ device = xp .Device ("device1" )
414
+ a = xp .asarray (0.0 , device = device )
415
+ assert pad (a , 2 ).device == device
416
+
417
+ def test_xp (self ):
418
+ assert_array_equal (pad (xp .asarray (0 ), 1 , xp = xp ), xp .zeros (3 ))
You can’t perform that action at this time.
0 commit comments