
Commit d193afa

Implement wrap padding
Parent: d3700de

2 files changed: +55 −6

pytensor/tensor/pad.py (+33 −2)
@@ -1,9 +1,16 @@
 from collections.abc import Callable
 from typing import Literal
 
+from pytensor.scan import scan
 from pytensor.tensor import TensorLike
-from pytensor.tensor.basic import TensorVariable, as_tensor, zeros
+from pytensor.tensor.basic import (
+    TensorVariable,
+    as_tensor,
+    moveaxis,
+    zeros,
+)
 from pytensor.tensor.extra_ops import broadcast_to, linspace
+from pytensor.tensor.math import divmod as pt_divmod
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import mean, minimum
 from pytensor.tensor.math import min as pt_min
@@ -12,7 +19,7 @@
 
 
 PadMode = Literal[
-    "constant", "edge", "linear_ramp", "maximum", "minimum", "mean", "median"
+    "constant", "edge", "linear_ramp", "maximum", "minimum", "mean", "median", "wrap"
 ]
 stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}
 
@@ -265,6 +272,28 @@ def _linear_ramp_pad(
     return padded
 
 
+def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
+    pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
+
+    for axis in range(x.ndim):
+        size = x.shape[axis]
+        repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size)
+
+        left_trim = size - left_remainder
+        right_trim = size - right_remainder
+        total_repeats = repeats.sum() + 3  # left, right, center
+
+        parts, _ = scan(lambda x: x, non_sequences=[x], n_steps=total_repeats)
+
+        parts = moveaxis(parts, 0, axis)
+        new_shape = [-1 if i == axis else x.shape[i] for i in range(x.ndim)]
+        x = parts.reshape(new_shape)
+        trim_slice = _slice_at_axis(slice(left_trim, -right_trim), axis)
+        x = x[trim_slice]
+
+    return x
+
+
 def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
     allowed_kwargs = {
         "edge": [],
@@ -300,6 +329,8 @@ def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwar
     elif mode == "linear_ramp":
         end_values = kwargs.pop("end_values", 0)
         return _linear_ramp_pad(x, pad_width, end_values)
+    elif mode == "wrap":
+        return _wrap_pad(x, pad_width)
 
 
 __all__ = ["pad"]
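
The approach in _wrap_pad is tile-and-trim: wrap padding treats the tensor as periodic along each axis, so the function stacks enough whole copies of x (via scan, which lets the repeat count stay symbolic) and then slices away the surplus so that exactly pad_width[axis] wrapped values remain on each side. A minimal NumPy sketch of the same idea along a single axis (the helper name and slicing below are illustrative, not part of this commit):

import numpy as np

def wrap_pad_1d(x: np.ndarray, left: int, right: int) -> np.ndarray:
    """Hypothetical 1-D sketch of the tile-and-trim wrap padding above."""
    size = x.shape[0]
    # Whole copies needed per side, plus the leftover elements to wrap.
    left_repeats, left_remainder = divmod(left, size)
    right_repeats, right_remainder = divmod(right, size)
    # One extra copy per side absorbs the remainder; the middle copy is x itself.
    tiled = np.concatenate([x] * (left_repeats + right_repeats + 3))
    # Trim the surplus so exactly `left` and `right` wrapped values remain.
    return tiled[size - left_remainder : -(size - right_remainder)]

x = np.arange(3)
assert np.array_equal(wrap_pad_1d(x, 1, 2), np.pad(x, (1, 2), mode="wrap"))

The scan in the commit plays the role of np.concatenate([x] * n) here, except that n_steps can be a symbolic quantity, which a Python-level repetition could not express.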

tests/tensor/test_pad.py (+22 −4)
@@ -5,6 +5,9 @@
 from pytensor.tensor.pad import PadMode, pad
 
 
+floatX = pytensor.config.floatX
+
+
 @pytest.mark.parametrize(
     "size", [(3,), (3, 3), (3, 3, 3)], ids=["1d", "2d square", "3d square"]
 )
@@ -13,7 +16,7 @@
 def test_constant_pad(
     size: tuple, constant: int | float, pad_width: int | tuple[int, ...]
 ):
-    x = np.random.normal(size=size)
+    x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="constant", constant_values=constant)
     z = pad(x, pad_width, mode="constant", constant_values=constant)
     f = pytensor.function([], z, mode="FAST_COMPILE")
@@ -28,7 +31,7 @@ def test_constant_pad(
     "pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
 )
 def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
-    x = np.random.normal(size=size)
+    x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="edge")
     z = pad(x, pad_width, mode="edge")
     f = pytensor.function([], z, mode="FAST_COMPILE")
@@ -48,7 +51,7 @@ def test_linear_ramp_pad(
     pad_width: int | tuple[int, ...],
     end_values: int | float | tuple[int | float, ...],
 ):
-    x = np.random.normal(size=size)
+    x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values)
     z = pad(x, pad_width, mode="linear_ramp", end_values=end_values)
     f = pytensor.function([], z, mode="FAST_COMPILE")
@@ -70,9 +73,24 @@ def test_stat_pad(
     stat: PadMode,
     stat_length: int | None,
 ):
-    x = np.random.normal(size=size)
+    x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length)
     z = pad(x, pad_width, mode=stat, stat_length=stat_length)
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f())
+
+
+@pytest.mark.parametrize(
+    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
+)
+@pytest.mark.parametrize(
+    "pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
+)
+def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
+    x = np.random.normal(size=size).astype(floatX)
+    expected = np.pad(x, pad_width, mode="wrap")
+    z = pad(x, pad_width, mode="wrap")
+    f = pytensor.function([], z, mode="FAST_COMPILE")
+
+    np.testing.assert_allclose(expected, f())
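
As a usage sketch mirroring test_wrap_pad above (not part of the commit), the new mode should agree with NumPy's periodic padding:

import numpy as np
import pytensor
from pytensor.tensor.pad import pad

x = np.arange(3).astype(pytensor.config.floatX)  # [0, 1, 2]
z = pad(x, (1, 2), mode="wrap")                  # wrap 1 value left, 2 right
f = pytensor.function([], z, mode="FAST_COMPILE")
print(f())  # should print [2. 0. 1. 2. 0. 1.], matching np.pad(x, (1, 2), mode="wrap")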
