Skip to content

Commit 21cc1bd

Browse files
committed
ENH: allow tuple pad_width in pad
1 parent a96dffb commit 21cc1bd

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

src/array_api_extra/_funcs.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
555555

556556
def pad(
557557
x: Array,
558-
pad_width: int,
558+
pad_width: int | tuple,
559559
mode: str = "constant",
560560
*,
561561
xp: ModuleType | None = None,
@@ -568,8 +568,9 @@ def pad(
568568
----------
569569
x : array
570570
Input array.
571-
pad_width : int
571+
pad_width : int or tuple of ints
572572
Pad the input array with this many elements from each side.
573+
Ifa tuple, each element applies to the corresponding axis of `x`.
573574
mode : str, optional
574575
Only "constant" mode is currently supported, which pads with
575576
the value passed to `constant_values`.
@@ -590,16 +591,23 @@ def pad(
590591

591592
value = constant_values
592593

594+
if isinstance(pad_width, int):
595+
pad_width = (pad_width,) * x.ndim
596+
593597
if xp is None:
594598
xp = array_namespace(x)
595599

596600
padded = xp.full(
597-
tuple(x + 2 * pad_width for x in x.shape),
601+
tuple(x + 2 * w for (x, w) in zip(x.shape, pad_width)),
598602
fill_value=value,
599603
dtype=x.dtype,
600604
device=_compat.device(x),
601605
)
602-
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
606+
sl = tuple(
607+
slice(w, -w, None) if w > 0 else slice(None, None, None)
608+
for w in pad_width
609+
)
610+
padded[sl] = x
603611
return padded
604612

605613

tests/test_funcs.py

+8
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,11 @@ def test_device(self):
416416

417417
def test_xp(self):
418418
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))
419+
420+
def test_tuple_width(self):
421+
a = xp.reshape(xp.arange(12), (3, 4))
422+
padded = pad(a, (1, 0))
423+
assert padded.shape == (5, 4)
424+
425+
padded = pad(a, (1, 2))
426+
assert padded.shape == (5, 8)

0 commit comments

Comments
 (0)