@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555
555
556
556
def pad (
557
557
x : Array ,
558
- pad_width : int ,
558
+ pad_width : int | tuple | list ,
559
559
mode : str = "constant" ,
560
560
* ,
561
561
xp : ModuleType | None = None ,
@@ -568,8 +568,12 @@ def pad(
568
568
----------
569
569
x : array
570
570
Input array.
571
- pad_width : int
571
+ pad_width : int or tuple of ints or list of pairs of ints
572
572
Pad the input array with this many elements from each side.
573
+ If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
574
+ each pair applies to the corresponding axis of ``x``.
575
+ A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
576
+ copies of this tuple.
573
577
mode : str, optional
574
578
Only "constant" mode is currently supported, which pads with
575
579
the value passed to `constant_values`.
@@ -590,16 +594,43 @@ def pad(
590
594
591
595
value = constant_values
592
596
597
+ # make pad_width a list of length-2 tuples of ints
598
+ if isinstance (pad_width , int ):
599
+ pad_width = [(pad_width , pad_width )] * x .ndim
600
+
601
+ if isinstance (pad_width , tuple ):
602
+ pad_width = [pad_width ] * x .ndim
603
+
593
604
if xp is None :
594
605
xp = array_namespace (x )
595
606
607
+ slices = []
608
+ newshape = []
609
+ for ax , w_tpl in enumerate (pad_width ):
610
+ if len (w_tpl ) != 2 :
611
+ msg = f"expect a 2-tuple (before, after), got { w_tpl } ."
612
+ raise ValueError (msg )
613
+
614
+ sh = x .shape [ax ]
615
+ if w_tpl [0 ] == 0 and w_tpl [1 ] == 0 :
616
+ sl = slice (None , None , None )
617
+ else :
618
+ start , stop = w_tpl
619
+ stop = None if stop == 0 else - stop
620
+
621
+ sl = slice (start , stop , None )
622
+ sh += w_tpl [0 ] + w_tpl [1 ]
623
+
624
+ newshape .append (sh )
625
+ slices .append (sl )
626
+
596
627
padded = xp .full (
597
- tuple (x + 2 * pad_width for x in x . shape ),
628
+ tuple (newshape ),
598
629
fill_value = value ,
599
630
dtype = x .dtype ,
600
631
device = _compat .device (x ),
601
632
)
602
- padded [( slice ( pad_width , - pad_width , None ),) * x . ndim ] = x
633
+ padded [tuple ( slices ) ] = x
603
634
return padded
604
635
605
636
0 commit comments