
Commit 3186740

Add pt.pad
1 parent 2a4d081 commit 3186740

File tree: 3 files changed, +508 −0 lines changed


pytensor/tensor/__init__.py (+1)

@@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 from pytensor.tensor.extra_ops import *
 from pytensor.tensor.io import *
 from pytensor.tensor.math import *
+from pytensor.tensor.pad import pad
 from pytensor.tensor.shape import (
     reshape,
     shape,
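
This one-line change re-exports `pad` from the top-level tensor namespace, which is what makes the `pt.pad` spelling in the commit title work. A minimal sketch (editor's addition, assuming the conventional `pt` alias):

    import pytensor.tensor as pt

    pt.pad  # after this commit, resolves to pytensor.tensor.pad.pad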

pytensor/tensor/pad.py (new file, +384)

@@ -0,0 +1,384 @@
from collections.abc import Callable
from typing import Literal

from pytensor.scan import scan
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
    TensorVariable,
    arange,
    as_tensor,
    moveaxis,
    switch,
    zeros,
)
from pytensor.tensor.extra_ops import broadcast_to, linspace
from pytensor.tensor.math import divmod as pt_divmod
from pytensor.tensor.math import eq, mean, minimum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.subtensor import set_subtensor


PadMode = Literal[
    "constant",
    "edge",
    "linear_ramp",
    "maximum",
    "minimum",
    "mean",
    "median",
    "wrap",
    "symmetric",
    "reflect",
]
stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}

def _slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
    """
    Construct a tuple of slices to slice an array in the given dimension.

    Copied from numpy.lib.arraypad._slice_at_axis
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced".

    Returns
    -------
    sl : tuple of slices
        A tuple that applies `sl` along `axis` and leaves all other
        dimensions unsliced.

    Examples
    --------
    >>> _slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    """
    return (slice(None),) * axis + (sl,) + (...,)  # type: ignore

def _get_edges(
    padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable]
) -> tuple[TensorVariable, TensorVariable]:
    """
    Retrieve edge values from an empty-padded array in the given dimension.

    Copied from numpy.lib.arraypad._get_edges
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L154

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the edges are considered.
    width_pair : (TensorVariable, TensorVariable)
        Pair of widths that mark the pad area on both sides in the given
        dimension.

    Returns
    -------
    left_edge, right_edge : TensorVariable
        Edge values of the valid area in `padded` in the given dimension. Their
        shapes always match `padded`, except for the dimension given by `axis`,
        which has a length of 1.
    """
    left_index = width_pair[0]
    left_slice = _slice_at_axis(slice(left_index, left_index + 1), axis)
    left_edge = padded[left_slice]

    right_index = padded.shape[axis] - width_pair[1]
    right_slice = _slice_at_axis(slice(right_index - 1, right_index), axis)
    right_edge = padded[right_slice]

    return left_edge, right_edge

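# Illustrative note (editor's addition, not part of the commit): for a padded
# vector of length 6 with width_pair = (1, 2), the valid area is padded[1:4],
# so left_edge is padded[1:2] and right_edge is padded[3:4]. The sliced axis
# keeps length 1, so the edges broadcast cleanly against the pad regions.
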
def _symbolic_pad(
    x: TensorVariable, pad_width: TensorVariable
) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
    pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))
    new_shape = as_tensor(
        [pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)]
    )
    original_area_slice = tuple(
        slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape)
    )
    padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x)
    return padded, original_area_slice, pad_width

def _get_padding_slices(
118+
dim_shape: TensorVariable,
119+
width_pair: tuple[TensorVariable, TensorVariable],
120+
axis: int,
121+
) -> tuple[tuple[slice, ...], tuple[slice, ...]]:
122+
left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
123+
right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
124+
125+
return left_slice, right_slice
126+
127+
128+
def _constant_pad(
129+
x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
130+
) -> TensorVariable:
131+
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
132+
values = broadcast_to(constant_values, as_tensor((padded.ndim, 2)))
133+
134+
for axis in range(padded.ndim):
135+
width_pair = pad_width[axis]
136+
value_pair = values[axis]
137+
dim_shape = padded.shape[axis]
138+
139+
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
140+
padded = set_subtensor(padded[left_slice], value_pair[0])
141+
padded = set_subtensor(padded[right_slice], value_pair[1])
142+
143+
return padded
144+
145+
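# Illustrative note (editor's addition): this mirrors np.pad's "constant" mode,
# e.g. np.pad([1, 2, 3], (1, 2), "constant", constant_values=(9, 0)) gives
# [9, 1, 2, 3, 0, 0].
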
def _edge_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        dim_shape = padded.shape[axis]

        left_edge, right_edge = _get_edges(padded, axis, width_pair)
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)

        padded = set_subtensor(padded[left_slice], left_edge)
        padded = set_subtensor(padded[right_slice], right_edge)

    return padded

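# Illustrative note (editor's addition): this mirrors np.pad's "edge" mode,
# e.g. np.pad([1, 2, 3], (2, 2), "edge") gives [1, 1, 1, 2, 3, 3, 3].
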
def _get_stats(
    padded: TensorVariable,
    axis: int,
    width_pair: TensorVariable,
    length_pair: tuple[TensorVariable, TensorVariable] | tuple[None, None],
    stat_func: Callable,
):
    """
    Calculate a statistic for an empty-padded array in the given dimension.

    Copied from numpy.lib.arraypad._get_stats
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L230

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the statistic is calculated.
    width_pair : (TensorVariable, TensorVariable)
        Pair of widths that mark the pad area on both sides in the given
        dimension.
    length_pair : 2-element sequence of None or TensorVariable
        Gives the number of values in the valid area from each side that is
        taken into account when calculating the statistic. If None, the entire
        valid area in `padded` is considered.
    stat_func : function
        Function to compute the statistic. The expected signature is
        ``stat_func(x: TensorVariable, axis: int, keepdims: bool) -> TensorVariable``.

    Returns
    -------
    left_stat, right_stat : TensorVariable
        Calculated statistic for both sides of `padded`.
    """
    # Calculate indices of the edges of the area with original values
    left_index = width_pair[0]
    right_index = padded.shape[axis] - width_pair[1]
    # as well as its length
    max_length = right_index - left_index

    # Limit stat_lengths to max_length
    left_length, right_length = length_pair
    # Remember whether both lengths were None before they are overwritten below;
    # in that case both sides use the whole valid area and the statistics match.
    no_stat_length = left_length is None and right_length is None

    # Calculate statistic for the left side
    left_length = (
        minimum(left_length, max_length) if left_length is not None else max_length
    )
    left_slice = _slice_at_axis(slice(left_index, left_index + left_length), axis)
    left_chunk = padded[left_slice]
    left_stat = stat_func(left_chunk, axis=axis, keepdims=True)
    if no_stat_length:
        # We could also return early in the more general case of
        # left_length == right_length, but we don't necessarily know these shapes.
        # TODO: Add rewrite to simplify in this case
        return left_stat, left_stat

    # Calculate statistic for the right side
    right_length = (
        minimum(right_length, max_length) if right_length is not None else max_length
    )
    right_slice = _slice_at_axis(slice(right_index - right_length, right_index), axis)
    right_chunk = padded[right_slice]
    right_stat = stat_func(right_chunk, axis=axis, keepdims=True)

    return left_stat, right_stat

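# Illustrative note (editor's addition): with stat_func = mean and
# length_pair = (2, None) on a valid area [1, 2, 3, 4], the left statistic is
# mean([1, 2]) = 1.5, while the right statistic uses the whole valid area:
# mean([1, 2, 3, 4]) = 2.5.
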
def _stat_pad(
    x: TensorVariable, pad_width: TensorVariable, stat_func, stat_length=None
):
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    if stat_length is None:
        stat_length = [[None, None]] * padded.ndim
    else:
        stat_length = broadcast_to(stat_length, as_tensor((padded.ndim, 2)))

    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        length_pair = stat_length[axis]
        dim_shape = padded.shape[axis]

        left_stat, right_stat = _get_stats(
            padded, axis, width_pair, length_pair, stat_func
        )
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
        padded = set_subtensor(padded[left_slice], left_stat)
        padded = set_subtensor(padded[right_slice], right_stat)

    return padded

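# Illustrative note (editor's addition): this mirrors np.pad's statistic modes,
# e.g. np.pad([1, 2, 3], (1, 2), "maximum") gives [3, 1, 2, 3, 3, 3].
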
def _linear_ramp_pad(
    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0
) -> TensorVariable:
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    end_values = as_tensor(end_values)
    end_values = broadcast_to(end_values, as_tensor((padded.ndim, 2)))

    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        end_value_pair = end_values[axis]
        edge_pair = _get_edges(padded, axis, width_pair)
        dim_shape = padded.shape[axis]
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)

        left_ramp, right_ramp = (
            linspace(
                start=end_value,
                stop=specify_broadcastable(edge, axis).squeeze(axis),
                num=width,
                endpoint=False,
                dtype=padded.dtype,
                axis=axis,
            )
            for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair)
        )

        # Reverse the direction of the ramp for the "right" side
        right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]  # type: ignore

        padded = set_subtensor(padded[left_slice], left_ramp)
        padded = set_subtensor(padded[right_slice], right_ramp)

    return padded

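# Illustrative note (editor's addition): this mirrors np.pad's "linear_ramp"
# mode; per numpy's documented example,
# np.pad([1, 2, 3, 4, 5], (2, 3), "linear_ramp", end_values=(5, -4)) gives
# [5, 3, 1, 2, 3, 4, 5, 2, -1, -4], ramping from end_values up to the edges.
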
def flip(x, axis=None):
    if axis is None:
        index = (slice(None, None, -1),) * x.ndim
    else:
        if isinstance(axis, int):
            axis = [axis]
        # Index with a tuple of slices, not a list, so this is basic (not
        # advanced) indexing.
        index = tuple(
            slice(None, None, -1) if i in axis else slice(None, None, None)
            for i in range(x.ndim)
        )
    return x[index]

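# Illustrative note (editor's addition): flip reverses x along the given axes
# via negative-step slices, e.g. a vector [1, 2, 3] becomes [3, 2, 1];
# _looping_pad uses it to alternate orientation for "symmetric" padding.
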
def _looping_pad(
    x: TensorVariable, pad_width: TensorVariable, kind: str
) -> TensorVariable:
    pad_width = broadcast_to(pad_width, as_tensor((x.ndim, 2)))

    for axis in range(x.ndim):
        if kind == "wrap":

            def inner_func(i, x):
                return x

        elif kind == "symmetric":
            # Define this function here so it closes over the current value of
            # `axis`, which the scan body needs.
            def inner_func(i, x):
                return switch(eq(i % 2, 0), flip(x, axis=axis), x)

        size = x.shape[axis]
        repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size)

        left_trim = size - left_remainder
        right_trim = size - right_remainder
        total_repeats = repeats.sum() + 3  # left, right, center

        parts, _ = scan(inner_func, non_sequences=[x], sequences=arange(total_repeats))

        parts = moveaxis(parts, 0, axis)
        new_shape = [-1 if i == axis else x.shape[i] for i in range(x.ndim)]
        x = parts.reshape(new_shape)
        trim_slice = _slice_at_axis(slice(left_trim, -right_trim), axis)
        x = x[trim_slice]

    return x

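# Illustrative worked example (editor's addition): for size = 3 and
# pad_width[axis] = (4, 2), pt_divmod gives repeats = (1, 0) with remainders
# (1, 2), so total_repeats = (1 + 0) + 3 = 4 copies (12 values) are stacked.
# Trimming left_trim = 3 - 1 = 2 and right_trim = 3 - 2 = 1 leaves
# 4 + 3 + 2 = 9 values, matching np.pad([1, 2, 3], (4, 2), "wrap").
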
def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
    allowed_kwargs = {
        "edge": [],
        "wrap": [],
        "constant": ["constant_values"],
        "linear_ramp": ["end_values"],
        "maximum": ["stat_length"],
        "mean": ["stat_length"],
        "median": ["stat_length"],
        "minimum": ["stat_length"],
        "reflect": ["reflect_type"],
        "symmetric": ["reflect_type"],
    }

    # Check the mode first so an unknown mode raises ValueError, not KeyError
    if mode not in allowed_kwargs:
        raise ValueError(f"Invalid mode: {mode}")
    if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
        raise ValueError(
            f"Invalid keyword arguments for mode '{mode}': {list(kwargs.keys())}"
        )
    x = as_tensor(x)
    pad_width = as_tensor(pad_width)

    if mode == "constant":
        constant_values = as_tensor(kwargs.pop("constant_values", 0))
        return _constant_pad(x, pad_width, constant_values)
    elif mode == "edge":
        return _edge_pad(x, pad_width)
    elif mode in ["maximum", "minimum", "mean", "median"]:
        if mode == "median":
            # TODO: pt.quantile? pt.median?
            raise NotImplementedError("Median padding not implemented")
        stat_func = stat_funcs[mode]
        return _stat_pad(x, pad_width, stat_func, **kwargs)
    elif mode == "linear_ramp":
        end_values = kwargs.pop("end_values", 0)
        return _linear_ramp_pad(x, pad_width, end_values)
    elif mode == "wrap":
        return _looping_pad(x, pad_width, kind="wrap")
    elif mode == "symmetric":
        reflect_type = kwargs.pop("reflect_type", "even")
        if reflect_type == "odd":
            raise NotImplementedError("Odd reflection not implemented")
        return _looping_pad(x, pad_width, kind="symmetric")
    elif mode == "reflect":
        reflect_type = kwargs.pop("reflect_type", "even")
        if reflect_type == "odd":
            raise NotImplementedError("Odd reflection not implemented")
        raise NotImplementedError("Reflect padding not implemented")
    else:
        raise ValueError(f"Invalid mode: {mode}")


__all__ = ["pad"]
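
A minimal usage sketch of the new function (editor's addition, not part of the commit; printed values assume the default float64 `floatX`):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")

    # Constant padding: one zero on the left, two on the right
    f = pytensor.function([x], pt.pad(x, (1, 2), mode="constant", constant_values=0))
    print(f(np.array([1.0, 2.0, 3.0])))  # [0. 1. 2. 3. 0. 0.]

    # Edge padding: repeat the border value twice on each side
    g = pytensor.function([x], pt.pad(x, 2, mode="edge"))
    print(g(np.array([1.0, 2.0, 3.0])))  # [1. 1. 1. 2. 3. 3. 3.]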
