Skip to content

Commit 5aac0b4

Browse files
committed
remove delegation for now
1 parent 96bfbbf commit 5aac0b4

File tree

4 files changed

+37
-40
lines changed

4 files changed

+37
-40
lines changed

src/array_api_extra/__init__.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc, pad
3+
from ._funcs import (
4+
atleast_nd,
5+
cov,
6+
create_diagonal,
7+
expand_dims,
8+
kron,
9+
pad,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.4.1.dev0"
615

@@ -12,7 +21,7 @@
1221
"create_diagonal",
1322
"expand_dims",
1423
"kron",
24+
"pad",
1525
"setdiff1d",
1626
"sinc",
17-
"pad",
1827
]

src/array_api_extra/_funcs.py

+24-30
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._lib import _compat, _utils
66
from ._lib._compat import (
7-
array_namespace, is_torch_namespace, is_array_api_strict_namespace
7+
array_namespace,
88
)
99
from ._lib._typing import Array, ModuleType
1010

@@ -14,9 +14,9 @@
1414
"create_diagonal",
1515
"expand_dims",
1616
"kron",
17+
"pad",
1718
"setdiff1d",
1819
"sinc",
19-
"pad",
2020
]
2121

2222

@@ -543,52 +543,46 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
543543
return xp.sin(y) / y
544544

545545

546-
def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs):
546+
def pad(
547+
x: Array,
548+
pad_width: int,
549+
mode: str = "constant",
550+
*,
551+
xp: ModuleType | None = None,
552+
constant_values: bool | int | float | complex = 0,
553+
) -> Array:
547554
"""
548555
Pad the input array.
549556
550557
Parameters
551558
----------
552559
x : array
553-
Input array
554-
pad_width: int
555-
Pad the input array with this many elements from each side
556-
mode: str, optional
560+
Input array.
561+
pad_width : int
562+
Pad the input array with this many elements from each side.
563+
mode : str, optional
557564
Only "constant" mode is currently supported.
558565
xp : array_namespace, optional
559566
The standard-compatible namespace for `x`. Default: infer.
560-
constant_values: python scalar, optional
567+
constant_values : python scalar, optional
561568
Use this value to pad the input. Default is zero.
562569
563570
Returns
564571
-------
565572
array
566-
The input array, padded with ``pad_width`` elements equal to ``constant_values``
573+
The input array,
574+
padded with ``pad_width`` elements equal to ``constant_values``.
567575
"""
568-
# xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
569-
# http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
570-
571-
if mode != 'constant':
576+
if mode != "constant":
572577
raise NotImplementedError()
573578

574-
value = kwargs.get("constant_values", 0)
575-
if kwargs and list(kwargs.keys()) != ['constant_values']:
576-
raise ValueError(f"Unknown kwargs: {kwargs}")
579+
value = constant_values
577580

578581
if xp is None:
579582
xp = array_namespace(x)
580583

581-
if is_array_api_strict_namespace(xp):
582-
padded = xp.full(
583-
tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype
584-
)
585-
padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x
586-
return padded
587-
elif is_torch_namespace(xp):
588-
pad_width = xp.asarray(pad_width)
589-
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
590-
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
591-
return xp.nn.functional.pad(x, tuple(pad_width), value=value)
592-
593-
else:
594-
return xp.pad(x, pad_width, mode=mode, **kwargs)
584+
padded = xp.full(
585+
tuple(x + 2 * pad_width for x in x.shape), fill_value=value, dtype=x.dtype
586+
)
587+
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
588+
return padded

src/array_api_extra/_lib/_compat.py

-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from array_api_compat import ( # pyright: ignore[reportMissingTypeStubs]
1212
array_namespace, # pyright: ignore[reportUnknownVariableType]
1313
device,
14-
is_torch_namespace,
15-
is_array_api_strict_namespace,
1614
)
1715

1816
__all__ = [

tests/test_funcs.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
create_diagonal,
1414
expand_dims,
1515
kron,
16+
pad,
1617
setdiff1d,
1718
sinc,
18-
pad,
1919
)
2020
from array_api_extra._lib._typing import Array
2121

@@ -400,10 +400,6 @@ def test_fill_value(self):
400400
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))
401401

402402
def test_ndim(self):
403-
a = xp.reshape(xp.arange(2*3*4), (2, 3, 4))
403+
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
404404
padded = pad(a, 2)
405405
assert padded.shape == (6, 7, 8)
406-
407-
def test_typo(self):
408-
with pytest.raises(ValueError, match="Unknown"):
409-
pad(xp.arange(2), pad_width=3, oops=3)

0 commit comments

Comments
 (0)