Skip to content

Commit 2e9a24f

Browse files
committed
ENH: allow list/tuple pad_width in pad
1 parent a96dffb commit 2e9a24f

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

src/array_api_extra/_funcs.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555555

556556
def pad(
557557
x: Array,
558-
pad_width: int,
558+
pad_width: int | tuple | list,
559559
mode: str = "constant",
560560
*,
561561
xp: ModuleType | None = None,
@@ -568,8 +568,12 @@ def pad(
568568
----------
569569
x : array
570570
Input array.
571-
pad_width : int
571+
pad_width : int or tuple of ints or list of pairs of ints
572572
Pad the input array with this many elements from each side.
573+
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``,
574+
each pair applies to the corresponding axis of ``x``.
575+
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim``
576+
copies of this tuple.
573577
mode : str, optional
574578
Only "constant" mode is currently supported, which pads with
575579
the value passed to `constant_values`.
@@ -590,16 +594,43 @@ def pad(
590594

591595
value = constant_values
592596

597+
# make pad_width a list of length-2 tuples of ints
598+
if isinstance(pad_width, int):
599+
pad_width = [(pad_width, pad_width)] * x.ndim
600+
601+
if isinstance(pad_width, tuple):
602+
pad_width = [pad_width] * x.ndim
603+
593604
if xp is None:
594605
xp = array_namespace(x)
595606

607+
slices = []
608+
newshape = []
609+
for ax, w_tpl in enumerate(pad_width):
610+
if len(w_tpl) != 2:
611+
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
612+
raise ValueError(msg)
613+
614+
sh = x.shape[ax]
615+
if w_tpl[0] == 0 and w_tpl[1] == 0:
616+
sl = slice(None, None, None)
617+
else:
618+
start, stop = w_tpl
619+
stop = None if stop == 0 else -stop
620+
621+
sl = slice(start, stop, None)
622+
sh += w_tpl[0] + w_tpl[1]
623+
624+
newshape.append(sh)
625+
slices.append(sl)
626+
596627
padded = xp.full(
597-
tuple(x + 2 * pad_width for x in x.shape),
628+
tuple(newshape),
598629
fill_value=value,
599630
dtype=x.dtype,
600631
device=_compat.device(x),
601632
)
602-
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
633+
padded[tuple(slices)] = x
603634
return padded
604635

605636

tests/test_funcs.py

+16
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,19 @@ def test_device(self):
416416

417417
def test_xp(self):
418418
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))
419+
420+
def test_tuple_width(self):
421+
a = xp.reshape(xp.arange(12), (3, 4))
422+
padded = pad(a, (1, 0))
423+
assert padded.shape == (4, 5)
424+
425+
padded = pad(a, (1, 2))
426+
assert padded.shape == (6, 7)
427+
428+
def test_list_of_tuples_width(self):
429+
a = xp.reshape(xp.arange(12), (3, 4))
430+
padded = pad(a, [(1, 0), (0, 2)])
431+
assert padded.shape == (4, 6)
432+
433+
padded = pad(a, [(1, 0), (0, 0)])
434+
assert padded.shape == (4, 4)

0 commit comments

Comments
 (0)