Skip to content

Commit d3700de

Browse files
Implement linear_ramp_pad and stat_pad
1 parent 5b1b2d6 commit d3700de

File tree

2 files changed

+288
-38
lines changed

2 files changed

+288
-38
lines changed

pytensor/tensor/pad.py

+210-38
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,91 @@
1-
from numpy.lib.arraypad import _get_edges, _slice_at_axis # noqa
2-
3-
from pytensor.tensor.basic import (
4-
TensorVariable,
5-
as_tensor,
6-
swapaxes,
7-
zeros,
8-
)
9-
from pytensor.tensor.extra_ops import linspace, broadcast_to
1+
from collections.abc import Callable
2+
from typing import Literal
3+
4+
from pytensor.tensor import TensorLike
5+
from pytensor.tensor.basic import TensorVariable, as_tensor, zeros
6+
from pytensor.tensor.extra_ops import broadcast_to, linspace
7+
from pytensor.tensor.math import max as pt_max
8+
from pytensor.tensor.math import mean, minimum
9+
from pytensor.tensor.math import min as pt_min
1010
from pytensor.tensor.shape import specify_broadcastable
1111
from pytensor.tensor.subtensor import set_subtensor
1212

1313

14+
PadMode = Literal[
15+
"constant", "edge", "linear_ramp", "maximum", "minimum", "mean", "median"
16+
]
17+
stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}
18+
19+
20+
def _slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
21+
"""
22+
Construct tuple of slices to slice an array in the given dimension.
23+
24+
Copied from numpy.lib.arraypad._slice_at_axis
25+
https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33
26+
27+
Parameters
28+
----------
29+
sl : slice
30+
The slice for the given dimension.
31+
axis : int
32+
The axis to which `sl` is applied. All other dimensions are left
33+
"unsliced".
34+
35+
Returns
36+
-------
37+
sl : tuple of slices
38+
A tuple with slices matching `shape` in length.
39+
40+
Examples
41+
--------
42+
>>> _slice_at_axis(slice(None, 3, -1), 1)
43+
(slice(None, None, None), slice(None, 3, -1), (...,))
44+
"""
45+
return (slice(None),) * axis + (sl,) + (...,) # type: ignore
46+
47+
48+
def _get_edges(
    padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable]
) -> tuple[TensorVariable, TensorVariable]:
    """
    Retrieve edge values from empty-padded array in given dimension.

    Copied from numpy.lib.arraypad._get_edges
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L154

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the edges are considered.
    width_pair : (TensorVariable, TensorVariable)
        Pair of widths that mark the pad area on both sides in the given
        dimension.

    Returns
    -------
    left_edge, right_edge : TensorVariable
        Edge values of the valid area in `padded` in the given dimension. Its
        shape will always match `padded` except for the dimension given by
        `axis` which will have a length of 1.
    """
    left_width, right_width = width_pair

    # First valid (non-pad) position along `axis`; keep a length-1 slice so the
    # result stays broadcastable against `padded`.
    start = left_width
    left_edge = padded[_slice_at_axis(slice(start, start + 1), axis)]

    # One past the last valid position along `axis`.
    stop = padded.shape[axis] - right_width
    right_edge = padded[_slice_at_axis(slice(stop - 1, stop), axis)]

    return left_edge, right_edge
83+
84+
1485
def _symbolic_pad(
1586
x: TensorVariable, pad_width: TensorVariable
1687
) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
17-
pad_width = broadcast_to(pad_width, (x.ndim, 2))
88+
pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
1889
new_shape = as_tensor(
1990
[pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)]
2091
)
@@ -26,8 +97,10 @@ def _symbolic_pad(
2697

2798

2899
def _get_padding_slices(
29-
dim_shape: TensorVariable, width_pair: tuple[TensorVariable, TensorVariable], axis: int
30-
):
100+
dim_shape: TensorVariable,
101+
width_pair: tuple[TensorVariable, TensorVariable],
102+
axis: int,
103+
) -> tuple[tuple[slice, ...], tuple[slice, ...]]:
31104
left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
32105
right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
33106

@@ -36,9 +109,9 @@ def _get_padding_slices(
36109

37110
def _constant_pad(
38111
x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
39-
):
112+
) -> TensorVariable:
40113
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
41-
values = broadcast_to(constant_values, (padded.ndim, 2))
114+
values = broadcast_to(constant_values, as_tensor((padded.ndim, 2)))
42115

43116
for axis in range(padded.ndim):
44117
width_pair = pad_width[axis]
@@ -52,7 +125,7 @@ def _constant_pad(
52125
return padded
53126

54127

55-
def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
128+
def _edge_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
56129
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
57130
for axis in range(padded.ndim):
58131
width_pair = pad_width[axis]
@@ -67,42 +140,133 @@ def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
67140
return padded
68141

69142

143+
def _get_stats(
    padded: TensorVariable,
    axis: int,
    width_pair: TensorVariable,
    length_pair: tuple[TensorVariable, TensorVariable] | tuple[None, None],
    stat_func: Callable,
):
    """
    Calculate statistic for the empty-padded array in given dimension.

    Copied from numpy.lib.arraypad._get_stats
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L230

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the statistic is calculated.
    width_pair : (TensorVariable, TensorVariable)
        Pair of widths that mark the pad area on both sides in the given dimension.
    length_pair : 2-element sequence of None or TensorVariable
        Gives the number of values in valid area from each side that is taken into account when calculating the
        statistic. If None the entire valid area in `padded` is considered.
    stat_func : function
        Function to compute statistic. The expected signature is
        ``stat_func(x: TensorVariable, axis: int, keepdims: bool) -> TensorVariable``.

    Returns
    -------
    left_stat, right_stat : TensorVariable
        Calculated statistic for both sides of `padded`.
    """
    # Calculate indices of the edges of the area with original values
    left_index = width_pair[0]
    right_index = padded.shape[axis] - width_pair[1]
    # as well as its length
    max_length = right_index - left_index

    # Limit stat_lengths to max_length
    left_length, right_length = length_pair

    # BUG FIX: record whether both lengths were None *before* they are
    # substituted below. The previous check (``left_length is None`` performed
    # after the substitution) could never be true, making the early return
    # unreachable.
    both_use_full_area = left_length is None and right_length is None

    # Calculate statistic for the left side
    left_length = (
        minimum(left_length, max_length) if left_length is not None else max_length
    )
    left_slice = _slice_at_axis(slice(left_index, left_index + left_length), axis)
    left_chunk = padded[left_slice]
    left_stat = stat_func(left_chunk, axis=axis, keepdims=True)
    if both_use_full_area:
        # Both sides cover the entire valid area, so the right statistic is
        # numerically identical to the left one and need not be recomputed.
        # We could also return early in the more general case of left_length == right_length, but we don't necessarily
        # know these shapes.
        # TODO: Add rewrite to simplify in this case
        return left_stat, left_stat

    # Calculate statistic for the right side
    right_length = (
        minimum(right_length, max_length) if right_length is not None else max_length
    )
    right_slice = _slice_at_axis(slice(right_index - right_length, right_index), axis)
    right_chunk = padded[right_slice]
    right_stat = stat_func(right_chunk, axis=axis, keepdims=True)

    return left_stat, right_stat
207+
208+
209+
def _stat_pad(
    x: TensorVariable, pad_width: TensorVariable, stat_func: Callable, stat_length=None
) -> TensorVariable:
    """
    Pad `x` with a statistic (e.g. mean, min, max) computed from the edges of
    the valid area along each dimension.

    Parameters
    ----------
    x : TensorVariable
        Tensor to pad.
    pad_width : TensorVariable
        Number of values padded to each edge of each axis; broadcast to
        shape ``(x.ndim, 2)`` by ``_symbolic_pad``.
    stat_func : Callable
        Statistic with signature ``stat_func(x, axis, keepdims)``, e.g. one of
        the functions in ``stat_funcs``.
    stat_length : None or TensorLike, optional
        Number of values at each edge used to compute the statistic. If None,
        the entire valid area is used on both sides of every axis.

    Returns
    -------
    TensorVariable
        The padded tensor.
    """
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    if stat_length is None:
        # One (None, None) pair per dimension: use the whole valid area.
        stat_length = [[None, None]] * padded.ndim
    else:
        # Normalize to one (left, right) length pair per dimension.
        stat_length = broadcast_to(stat_length, as_tensor((padded.ndim, 2)))

    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        length_pair = stat_length[axis]
        dim_shape = padded.shape[axis]

        # Statistic of the valid-area edges, with keepdims so it broadcasts
        # into the pad regions below.
        left_stat, right_stat = _get_stats(
            padded, axis, width_pair, length_pair, stat_func
        )
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
        padded = set_subtensor(padded[left_slice], left_stat)
        padded = set_subtensor(padded[right_slice], right_stat)

    return padded
231+
232+
70233
def _linear_ramp_pad(
    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0
) -> TensorVariable:
    """
    Pad `x` with linearly interpolated values between each edge of the valid
    area and the given ``end_values``.

    Parameters
    ----------
    x : TensorVariable
        Tensor to pad.
    pad_width : TensorVariable
        Number of values padded to each edge of each axis; broadcast to
        shape ``(x.ndim, 2)`` by ``_symbolic_pad``.
    end_values : TensorVariable | int, default 0
        Value(s) the ramp reaches at the outermost padded position; broadcast
        to shape ``(padded.ndim, 2)``.

    Returns
    -------
    TensorVariable
        The padded tensor.
    """
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    end_values = as_tensor(end_values)
    end_values = broadcast_to(end_values, as_tensor((padded.ndim, 2)))

    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        end_value_pair = end_values[axis]
        edge_pair = _get_edges(padded, axis, width_pair)
        dim_shape = padded.shape[axis]
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)

        # ``endpoint=False`` yields exactly `width` points running from the end
        # value towards (but excluding) the edge of the valid area.
        left_ramp, right_ramp = (
            linspace(
                start=end_value,
                stop=specify_broadcastable(edge, axis).squeeze(axis),
                num=width,
                endpoint=False,
                dtype=padded.dtype,
                axis=axis,
            )
            for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair)
        )

        # Reverse the direction of the ramp for the "right" side
        right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]  # type: ignore

        padded = set_subtensor(padded[left_slice], left_ramp)
        padded = set_subtensor(padded[right_slice], right_ramp)

    return padded
101266

102267

103-
def pad(x, pad_width, mode="constant", **kwargs):
268+
def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
104269
allowed_kwargs = {
105-
"empty": [],
106270
"edge": [],
107271
"wrap": [],
108272
"constant": ["constant_values"],
@@ -115,16 +279,24 @@ def pad(x, pad_width, mode="constant", **kwargs):
115279
"symmetric": ["reflect_type"],
116280
}
117281

118-
if any(value not in allowed_kwargs[mode] for value in kwargs.values()):
282+
if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
119283
raise ValueError(
120284
f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
121285
)
286+
x = as_tensor(x)
287+
pad_width = as_tensor(pad_width)
122288

123289
if mode == "constant":
124-
constant_values = kwargs.pop("constant_values", 0)
290+
constant_values = as_tensor(kwargs.pop("constant_values", 0))
125291
return _constant_pad(x, pad_width, constant_values)
126292
elif mode == "edge":
127293
return _edge_pad(x, pad_width)
294+
elif mode in ["maximum", "minimum", "mean", "median"]:
295+
if mode == "median":
296+
# TODO: pt.quantile? pt.median?
297+
raise NotImplementedError("Median padding not implemented")
298+
stat_func = stat_funcs[mode]
299+
return _stat_pad(x, pad_width, stat_func, **kwargs)
128300
elif mode == "linear_ramp":
129301
end_values = kwargs.pop("end_values", 0)
130302
return _linear_ramp_pad(x, pad_width, end_values)

tests/tensor/test_pad.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor
5+
from pytensor.tensor.pad import PadMode, pad
6+
7+
8+
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 3, 3)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize("constant", [0, 0.0], ids=["int", "float"])
@pytest.mark.parametrize("pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical"])
def test_constant_pad(
    size: tuple, constant: int | float, pad_width: int | tuple[int, ...]
):
    """Constant-mode padding agrees with np.pad for the same constant."""
    data = np.random.normal(size=size)
    graph = pad(data, pad_width, mode="constant", constant_values=constant)
    compiled = pytensor.function([], graph, mode="FAST_COMPILE")
    reference = np.pad(data, pad_width, mode="constant", constant_values=constant)

    np.testing.assert_allclose(reference, compiled())
22+
23+
24+
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
)
def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
    """Edge-mode padding agrees with np.pad's replication of boundary values."""
    data = np.random.normal(size=size)
    graph = pad(data, pad_width, mode="edge")
    compiled = pytensor.function([], graph, mode="FAST_COMPILE")
    reference = np.pad(data, pad_width, mode="edge")

    np.testing.assert_allclose(reference, compiled())
37+
38+
39+
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
)
@pytest.mark.parametrize("end_values", [0, -1], ids=["0", "-1"])
def test_linear_ramp_pad(
    size: tuple,
    pad_width: int | tuple[int, ...],
    end_values: int | float | tuple[int | float, ...],
):
    """Linear-ramp padding agrees with np.pad for the same end values."""
    data = np.random.normal(size=size)
    graph = pad(data, pad_width, mode="linear_ramp", end_values=end_values)
    compiled = pytensor.function([], graph, mode="FAST_COMPILE")
    reference = np.pad(data, pad_width, mode="linear_ramp", end_values=end_values)

    np.testing.assert_allclose(reference, compiled())
57+
58+
59+
@pytest.mark.parametrize(
    "size", [(3,), (3, 3), (3, 5, 5)], ids=["1d", "2d square", "3d square"]
)
@pytest.mark.parametrize(
    "pad_width", [1, (1, 2)], ids=["symmetrical", "asymmetrical_1d"]
)
@pytest.mark.parametrize("stat", ["mean", "minimum", "maximum"])
@pytest.mark.parametrize("stat_length", [None, 2])
def test_stat_pad(
    size: tuple,
    pad_width: int | tuple[int, ...],
    stat: PadMode,
    stat_length: int | None,
):
    """Statistic-mode padding agrees with np.pad for mean/min/max modes."""
    data = np.random.normal(size=size)
    graph = pad(data, pad_width, mode=stat, stat_length=stat_length)
    compiled = pytensor.function([], graph, mode="FAST_COMPILE")
    reference = np.pad(data, pad_width, mode=stat, stat_length=stat_length)

    np.testing.assert_allclose(reference, compiled())

0 commit comments

Comments
 (0)