|
4 | 4 |
|
5 | 5 | from ._lib import _compat, _utils
|
6 | 6 | from ._lib._compat import (
|
7 |
| - array_namespace, is_torch_namespace, is_array_api_strict_namespace |
| 7 | + array_namespace, |
8 | 8 | )
|
9 | 9 | from ._lib._typing import Array, ModuleType
|
10 | 10 |
|
|
14 | 14 | "create_diagonal",
|
15 | 15 | "expand_dims",
|
16 | 16 | "kron",
|
| 17 | + "pad", |
17 | 18 | "setdiff1d",
|
18 | 19 | "sinc",
|
19 |
| - "pad", |
20 | 20 | ]
|
21 | 21 |
|
22 | 22 |
|
@@ -543,52 +543,46 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
543 | 543 | return xp.sin(y) / y
|
544 | 544 |
|
545 | 545 |
|
546 |
| -def pad(x: Array, pad_width: int, mode: str = 'constant', *, xp: ModuleType = None, **kwargs): |
| 546 | +def pad( |
| 547 | + x: Array, |
| 548 | + pad_width: int, |
| 549 | + mode: str = "constant", |
| 550 | + *, |
| 551 | + xp: ModuleType | None = None, |
| 552 | + constant_values: bool | int | float | complex = 0, |
| 553 | +) -> Array: |
547 | 554 | """
|
548 | 555 | Pad the input array.
|
549 | 556 |
|
550 | 557 | Parameters
|
551 | 558 | ----------
|
552 | 559 | x : array
|
553 |
| - Input array |
554 |
| - pad_width: int |
555 |
| - Pad the input array with this many elements from each side |
556 |
| - mode: str, optional |
| 560 | + Input array. |
| 561 | + pad_width : int |
| 562 | + Pad the input array with this many elements from each side. |
| 563 | + mode : str, optional |
557 | 564 | Only "constant" mode is currently supported.
|
558 | 565 | xp : array_namespace, optional
|
559 | 566 | The standard-compatible namespace for `x`. Default: infer.
|
560 |
| - constant_values: python scalar, optional |
| 567 | + constant_values : python scalar, optional |
561 | 568 | Use this value to pad the input. Default is zero.
|
562 | 569 |
|
563 | 570 | Returns
|
564 | 571 | -------
|
565 | 572 | array
|
566 |
| - The input array, padded with ``pad_width`` elements equal to ``constant_values`` |
| 573 | + The input array, |
| 574 | + padded with ``pad_width`` elements equal to ``constant_values``. |
567 | 575 | """
|
568 |
| - # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse |
569 |
| - # http://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045 |
570 |
| - |
571 |
| - if mode != 'constant': |
| 576 | + if mode != "constant": |
572 | 577 | raise NotImplementedError()
|
573 | 578 |
|
574 |
| - value = kwargs.get("constant_values", 0) |
575 |
| - if kwargs and list(kwargs.keys()) != ['constant_values']: |
576 |
| - raise ValueError(f"Unknown kwargs: {kwargs}") |
| 579 | + value = constant_values |
577 | 580 |
|
578 | 581 | if xp is None:
|
579 | 582 | xp = array_namespace(x)
|
580 | 583 |
|
581 |
| - if is_array_api_strict_namespace(xp): |
582 |
| - padded = xp.full( |
583 |
| - tuple(x + 2*pad_width for x in x.shape), fill_value=value, dtype=x.dtype |
584 |
| - ) |
585 |
| - padded[(slice(pad_width, -pad_width, None),)*x.ndim] = x |
586 |
| - return padded |
587 |
| - elif is_torch_namespace(xp): |
588 |
| - pad_width = xp.asarray(pad_width) |
589 |
| - pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) |
590 |
| - pad_width = xp.flip(pad_width, axis=(0,)).flatten() |
591 |
| - return xp.nn.functional.pad(x, tuple(pad_width), value=value) |
592 |
| - |
593 |
| - else: |
594 |
| - return xp.pad(x, pad_width, mode=mode, **kwargs) |
| 584 | + padded = xp.full( |
| 585 | + tuple(x + 2 * pad_width for x in x.shape), fill_value=value, dtype=x.dtype |
| 586 | + ) |
| 587 | + padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x |
| 588 | + return padded |
0 commit comments