Skip to content

Commit f3f027d

Browse files
Basic padding functionality
1 parent b8e26cd commit f3f027d

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
from numpy.core.multiarray import normalize_axis_index
1818
from numpy.core.numeric import normalize_axis_tuple
19+
from numpy.lib.arraypad import _get_edges, _slice_at_axis
1920

2021
import pytensor
2122
import pytensor.scalar.sharedvar
@@ -50,6 +51,7 @@
5051
scalar_elemwise,
5152
)
5253
from pytensor.tensor.exceptions import NotScalarConstantError
54+
from pytensor.tensor.extra_ops import broadcast_to, linspace
5355
from pytensor.tensor.shape import (
5456
Shape,
5557
Shape_i,
@@ -61,6 +63,7 @@
6163
shape_tuple,
6264
specify_broadcastable,
6365
)
66+
from pytensor.tensor.subtensor import set_subtensor
6467
from pytensor.tensor.type import (
6568
TensorType,
6669
discrete_dtypes,
@@ -4342,6 +4345,125 @@ def ix_(*args):
43424345
return tuple(out)
43434346

43444347

4348+
def _symbolic_pad(
4349+
x: TensorVariable, pad_width: TensorVariable
4350+
) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
4351+
pad_width = broadcast_to(pad_width, (x.ndim, 2))
4352+
new_shape = as_tensor(
4353+
[pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)]
4354+
)
4355+
original_area_slice = tuple(
4356+
slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape)
4357+
)
4358+
padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x)
4359+
return padded, original_area_slice, pad_width
4360+
4361+
4362+
def _get_padding_slices(
4363+
dim_shape: TensorVariable, width_pair: tuple[TensorVariable], axis: int
4364+
):
4365+
left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
4366+
right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
4367+
4368+
return left_slice, right_slice
4369+
4370+
4371+
def _constant_pad(
4372+
x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
4373+
):
4374+
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
4375+
values = broadcast_to(constant_values, (padded.ndim, 2))
4376+
4377+
for axis in range(padded.ndim):
4378+
width_pair = pad_width[axis]
4379+
value_pair = values[axis]
4380+
dim_shape = padded.shape[axis]
4381+
4382+
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
4383+
padded = set_subtensor(padded[left_slice], value_pair[0])
4384+
padded = set_subtensor(padded[right_slice], value_pair[1])
4385+
4386+
return padded
4387+
4388+
4389+
def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
4390+
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
4391+
for axis in range(padded.ndim):
4392+
width_pair = pad_width[axis]
4393+
dim_shape = padded.shape[axis]
4394+
4395+
left_edge, right_edge = _get_edges(padded, axis, width_pair)
4396+
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
4397+
4398+
padded = set_subtensor(padded[left_slice], left_edge)
4399+
padded = set_subtensor(padded[right_slice], right_edge)
4400+
4401+
return padded
4402+
4403+
4404+
def _linear_ramp_pad(
4405+
x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable = 0
4406+
):
4407+
padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
4408+
end_values = broadcast_to(end_values, (padded.ndim, 2))
4409+
for axis in range(padded.ndim):
4410+
width_pair = pad_width[axis]
4411+
end_value_pair = end_values[axis]
4412+
edge_pair = _get_edges(padded, axis, width_pair)
4413+
dim_shape = padded.shape[axis]
4414+
left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
4415+
4416+
# pt.linspace doesn't have the endpoint kwarg, so need to take one extra step then slice it away
4417+
left_ramp = linspace(
4418+
start=end_value_pair[0],
4419+
end=specify_broadcastable(edge_pair[0], axis).squeeze(axis),
4420+
steps=width_pair[0] + 1,
4421+
)[:-1]
4422+
right_ramp = linspace(
4423+
start=end_value_pair[1],
4424+
end=specify_broadcastable(edge_pair[1], axis).squeeze(axis),
4425+
steps=width_pair[1] + 1,
4426+
)[:-1]
4427+
right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]
4428+
4429+
# FIXME: This swapaxes is needed because the shapes of the linspaces don't "rotate" with
4430+
# the different dimensions. But this makes the non-active dimensions backwards in the padding.
4431+
padded = set_subtensor(padded[left_slice], swapaxes(left_ramp, 0, axis))
4432+
padded = set_subtensor(padded[right_slice], swapaxes(right_ramp, 0, axis))
4433+
4434+
return padded
4435+
4436+
4437+
def pad(x, pad_width, mode="constant", **kwargs):
4438+
allowed_kwargs = {
4439+
"empty": [],
4440+
"edge": [],
4441+
"wrap": [],
4442+
"constant": ["constant_values"],
4443+
"linear_ramp": ["end_values"],
4444+
"maximum": ["stat_length"],
4445+
"mean": ["stat_length"],
4446+
"median": ["stat_length"],
4447+
"minimum": ["stat_length"],
4448+
"reflect": ["reflect_type"],
4449+
"symmetric": ["reflect_type"],
4450+
}
4451+
4452+
if any(value not in allowed_kwargs[mode] for value in kwargs.values()):
4453+
raise ValueError(
4454+
f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
4455+
)
4456+
4457+
if mode == "constant":
4458+
constant_values = kwargs.pop("constant_values", 0)
4459+
return _constant_pad(x, pad_width, constant_values)
4460+
elif mode == "edge":
4461+
return _edge_pad(x, pad_width)
4462+
elif mode == "linear_ramp":
4463+
end_values = kwargs.pop("end_values", 0)
4464+
return _linear_ramp_pad(x, pad_width, end_values)
4465+
4466+
43454467
__all__ = [
43464468
"take_along_axis",
43474469
"expand_dims",

0 commit comments

Comments
 (0)