Skip to content

Commit d17fd2f

Browse files
committed
ENH: pad: add delegation
1 parent 169f21d commit d17fd2f

13 files changed

+110
-33
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_diagonal
1212
expand_dims
1313
kron
14+
pad
1415
setdiff1d
1516
sinc
1617
```

src/array_api_extra/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._funcs import (
3+
from ._lib._funcs import (
44
atleast_nd,
55
cov,
66
create_diagonal,

src/array_api_extra/_delegators.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Delegators to existing implementations for Public API Functions."""
2+
3+
from ._lib import _funcs
4+
from ._lib._utils._compat import (
5+
array_namespace,
6+
is_cupy_namespace,
7+
is_jax_namespace,
8+
is_numpy_namespace,
9+
is_torch_namespace,
10+
)
11+
from ._lib._utils._typing import Array, ModuleType
12+
13+
14+
def pad(
15+
x: Array,
16+
pad_width: int,
17+
mode: str = "constant",
18+
*,
19+
constant_values: bool | int | float | complex = 0,
20+
xp: ModuleType | None = None,
21+
) -> Array:
22+
"""
23+
Pad the input array.
24+
25+
Parameters
26+
----------
27+
x : array
28+
Input array.
29+
pad_width : int
30+
Pad the input array with this many elements from each side.
31+
mode : str, optional
32+
Only "constant" mode is currently supported, which pads with
33+
the value passed to `constant_values`.
34+
constant_values : python scalar, optional
35+
Use this value to pad the input. Default is zero.
36+
xp : array_namespace, optional
37+
The standard-compatible namespace for `x`. Default: infer.
38+
39+
Returns
40+
-------
41+
array
42+
The input array,
43+
padded with ``pad_width`` elements equal to ``constant_values``.
44+
"""
45+
xp = array_namespace(x) if xp is None else xp
46+
47+
value = constant_values
48+
49+
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
50+
if is_torch_namespace(xp):
51+
pad_width = xp.asarray(pad_width)
52+
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
53+
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
54+
return xp.nn.functional.pad(x, (pad_width,), value=value)
55+
56+
if is_numpy_namespace(x) or is_jax_namespace(xp) or is_cupy_namespace(xp):
57+
return xp.pad(x, pad_width, mode, constant_values=value)
58+
59+
return _funcs.pad(x, pad_width, mode, constant_values=value, xp=xp)

src/array_api_extra/_lib/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Modules housing private functions."""
1+
"""Array-agnostic implementations for the public API."""

src/array_api_extra/_lib/_compat.py

-19
This file was deleted.

src/array_api_extra/_funcs.py renamed to src/array_api_extra/_lib/_funcs.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
import warnings
44

5-
from ._lib import _compat, _utils
6-
from ._lib._compat import array_namespace
7-
from ._lib._typing import Array, ModuleType
5+
from ._utils import _compat, _helpers
6+
from ._utils._compat import array_namespace
7+
from ._utils._typing import Array, ModuleType
88

99
__all__ = [
1010
"atleast_nd",
@@ -136,7 +136,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
136136
m = atleast_nd(m, ndim=2, xp=xp)
137137
m = xp.astype(m, dtype)
138138

139-
avg = _utils.mean(m, axis=1, xp=xp)
139+
avg = _helpers.mean(m, axis=1, xp=xp)
140140
fact = m.shape[1] - 1
141141

142142
if fact <= 0:
@@ -449,7 +449,7 @@ def setdiff1d(
449449
else:
450450
x1 = xp.unique_values(x1)
451451
x2 = xp.unique_values(x2)
452-
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
452+
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
453453

454454

455455
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
@@ -546,8 +546,8 @@ def pad(
546546
pad_width: int,
547547
mode: str = "constant",
548548
*,
549-
xp: ModuleType | None = None,
550549
constant_values: bool | int | float | complex = 0,
550+
xp: ModuleType | None = None,
551551
) -> Array:
552552
"""
553553
Pad the input array.
@@ -561,10 +561,10 @@ def pad(
561561
mode : str, optional
562562
Only "constant" mode is currently supported, which pads with
563563
the value passed to `constant_values`.
564-
xp : array_namespace, optional
565-
The standard-compatible namespace for `x`. Default: infer.
566564
constant_values : python scalar, optional
567565
Use this value to pad the input. Default is zero.
566+
xp : array_namespace, optional
567+
The standard-compatible namespace for `x`. Default: infer.
568568
569569
Returns
570570
-------
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Modules housing private utility functions."""
+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Acquire helpers from array-api-compat."""
2+
# Allow packages that vendor both `array-api-extra` and
3+
# `array-api-compat` to override the import location
4+
5+
try:
6+
from ..._array_api_compat_vendor import ( # pyright: ignore[reportMissingImports]
7+
array_namespace, # pyright: ignore[reportUnknownVariableType]
8+
device, # pyright: ignore[reportUnknownVariableType]
9+
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]
10+
is_jax_namespace, # pyright: ignore[reportUnknownVariableType]
11+
is_numpy_namespace, # pyright: ignore[reportUnknownVariableType]
12+
is_torch_namespace, # pyright: ignore[reportUnknownVariableType]
13+
)
14+
except ImportError:
15+
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
16+
array_namespace, # pyright: ignore[reportUnknownVariableType]
17+
device,
18+
is_cupy_namespace, # pyright: ignore[reportUnknownVariableType]
19+
is_jax_namespace, # pyright: ignore[reportUnknownVariableType]
20+
is_numpy_namespace, # pyright: ignore[reportUnknownVariableType]
21+
is_torch_namespace, # pyright: ignore[reportUnknownVariableType]
22+
)
23+
24+
__all__ = [
25+
"array_namespace",
26+
"device",
27+
"is_cupy_namespace",
28+
"is_jax_namespace",
29+
"is_numpy_namespace",
30+
"is_torch_namespace",
31+
]

src/array_api_extra/_lib/_compat.pyi renamed to src/array_api_extra/_lib/_utils/_compat.pyi

+4
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ def array_namespace(
1515
use_compat: bool | None = None,
1616
) -> ArrayModule: ... # numpydoc ignore=GL08
1717
def device(x: Array, /) -> Device: ... # numpydoc ignore=GL08
18+
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
19+
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
20+
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
21+
def is_torch_namespace(xp: ModuleType, /) -> bool: ...

src/array_api_extra/_lib/_utils.py renamed to src/array_api_extra/_lib/_utils/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Utility functions used by `array_api_extra/_funcs.py`."""
1+
"""Helper functions used by `array_api_extra/_funcs.py`."""
22

33
from . import _compat
44
from ._typing import Array, ModuleType

tests/test_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
setdiff1d,
1818
sinc,
1919
)
20-
from array_api_extra._lib._typing import Array
20+
from array_api_extra._lib._utils._typing import Array
2121

2222

2323
class TestAtLeastND:

tests/test_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import pytest
44
from numpy.testing import assert_array_equal
55

6-
from array_api_extra._lib._typing import Array
7-
from array_api_extra._lib._utils import in1d
6+
from array_api_extra._lib._utils._helpers import in1d
7+
from array_api_extra._lib._utils._typing import Array
88

99

1010
# some test coverage already provided by TestSetDiff1D

0 commit comments

Comments
 (0)