@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555
555
556
556
def pad (
557
557
x : Array ,
558
- pad_width : int ,
558
+ pad_width : int | tuple ,
559
559
mode : str = "constant" ,
560
560
* ,
561
561
xp : ModuleType | None = None ,
@@ -568,8 +568,9 @@ def pad(
568
568
----------
569
569
x : array
570
570
Input array.
571
- pad_width : int
571
+ pad_width : int or tuple of ints
572
572
Pad the input array with this many elements from each side.
573
+ Ifa tuple, each element applies to the corresponding axis of `x`.
573
574
mode : str, optional
574
575
Only "constant" mode is currently supported, which pads with
575
576
the value passed to `constant_values`.
@@ -590,16 +591,23 @@ def pad(
590
591
591
592
value = constant_values
592
593
594
+ if isinstance (pad_width , int ):
595
+ pad_width = (pad_width ,) * x .ndim
596
+
593
597
if xp is None :
594
598
xp = array_namespace (x )
595
599
596
600
padded = xp .full (
597
- tuple (x + 2 * pad_width for x in x .shape ),
601
+ tuple (x + 2 * w for ( x , w ) in zip ( x .shape , pad_width ) ),
598
602
fill_value = value ,
599
603
dtype = x .dtype ,
600
604
device = _compat .device (x ),
601
605
)
602
- padded [(slice (pad_width , - pad_width , None ),) * x .ndim ] = x
606
+ sl = tuple (
607
+ slice (w , - w , None ) if w > 0 else slice (None , None , None )
608
+ for w in pad_width
609
+ )
610
+ padded [sl ] = x
603
611
return padded
604
612
605
613
0 commit comments