|
16 | 16 | import numpy as np
|
17 | 17 | from numpy.core.multiarray import normalize_axis_index
|
18 | 18 | from numpy.core.numeric import normalize_axis_tuple
|
| 19 | +from numpy.lib.arraypad import _get_edges, _slice_at_axis |
19 | 20 |
|
20 | 21 | import pytensor
|
21 | 22 | import pytensor.scalar.sharedvar
|
|
50 | 51 | scalar_elemwise,
|
51 | 52 | )
|
52 | 53 | from pytensor.tensor.exceptions import NotScalarConstantError
|
| 54 | +from pytensor.tensor.extra_ops import broadcast_to, linspace |
53 | 55 | from pytensor.tensor.shape import (
|
54 | 56 | Shape,
|
55 | 57 | Shape_i,
|
|
61 | 63 | shape_tuple,
|
62 | 64 | specify_broadcastable,
|
63 | 65 | )
|
| 66 | +from pytensor.tensor.subtensor import set_subtensor |
64 | 67 | from pytensor.tensor.type import (
|
65 | 68 | TensorType,
|
66 | 69 | discrete_dtypes,
|
@@ -4342,6 +4345,125 @@ def ix_(*args):
|
4342 | 4345 | return tuple(out)
|
4343 | 4346 |
|
4344 | 4347 |
|
def _symbolic_pad(
    x: TensorVariable, pad_width: TensorVariable
) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
    """Allocate a zeros tensor large enough for the requested padding and copy ``x`` into its interior.

    Returns the zero-initialized padded tensor, the slice tuple addressing the
    original (un-padded) region, and ``pad_width`` broadcast to shape ``(ndim, 2)``.
    """
    # Normalize pad_width so every axis has an explicit (before, after) pair.
    pad_width = broadcast_to(pad_width, (x.ndim, 2))

    padded_sizes = []
    interior_slices = []
    for axis, axis_size in enumerate(x.shape):
        before = pad_width[axis][0]
        after = pad_width[axis][1]
        padded_sizes.append(before + axis_size + after)
        interior_slices.append(slice(before, before + axis_size))

    original_area_slice = tuple(interior_slices)
    padded: TensorVariable = set_subtensor(
        zeros(as_tensor(padded_sizes))[original_area_slice], x
    )
    return padded, original_area_slice, pad_width
| 4360 | + |
| 4361 | + |
def _get_padding_slices(
    dim_shape: TensorVariable, width_pair: tuple[TensorVariable], axis: int
):
    """Build the two index tuples selecting the left and right pad regions of ``axis``.

    ``dim_shape`` is the *padded* length of the axis; ``width_pair`` holds the
    (before, after) pad widths for that axis.
    """
    before, after = width_pair[0], width_pair[1]
    return (
        _slice_at_axis(slice(None, before), axis),
        _slice_at_axis(slice(dim_shape - after, None), axis),
    )
| 4369 | + |
| 4370 | + |
def _constant_pad(
    x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
):
    """Pad ``x`` by writing fixed fill values into the border region of each axis.

    ``constant_values`` is broadcast to one (left, right) fill pair per axis,
    mirroring ``np.pad(..., mode="constant")``.
    """
    padded, _, pad_width = _symbolic_pad(x, pad_width)
    fill_pairs = broadcast_to(constant_values, (padded.ndim, 2))

    for axis in range(padded.ndim):
        widths = pad_width[axis]
        fills = fill_pairs[axis]
        left, right = _get_padding_slices(padded.shape[axis], widths, axis)

        padded = set_subtensor(padded[left], fills[0])
        padded = set_subtensor(padded[right], fills[1])

    return padded
| 4387 | + |
| 4388 | + |
def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
    """Pad ``x`` by replicating each axis' first/last values into the border region.

    Mirrors ``np.pad(..., mode="edge")``.
    """
    padded, _, pad_width = _symbolic_pad(x, pad_width)

    for axis in range(padded.ndim):
        widths = pad_width[axis]
        first_vals, last_vals = _get_edges(padded, axis, widths)
        left, right = _get_padding_slices(padded.shape[axis], widths, axis)

        padded = set_subtensor(padded[left], first_vals)
        padded = set_subtensor(padded[right], last_vals)

    return padded
| 4402 | + |
| 4403 | + |
def _linear_ramp_pad(
    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable = 0
):
    """Pad ``x`` with values ramping linearly from ``end_values`` to the edge values.

    Mirrors ``np.pad(..., mode="linear_ramp")``: along each axis, the left pad
    interpolates from ``end_values[axis][0]`` up to the first slice of ``x``,
    and the right pad interpolates from the last slice of ``x`` down to
    ``end_values[axis][1]``.
    """
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    # One (left, right) end value per axis, like np.pad's end_values handling.
    end_values = broadcast_to(end_values, (padded.ndim, 2))
    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        end_value_pair = end_values[axis]
        edge_pair = _get_edges(padded, axis, width_pair)
        dim_shape = padded.shape[axis]
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)

        # pt.linspace doesn't have the endpoint kwarg, so need to take one extra step then slice it away
        left_ramp = linspace(
            start=end_value_pair[0],
            end=specify_broadcastable(edge_pair[0], axis).squeeze(axis),
            steps=width_pair[0] + 1,
        )[:-1]
        right_ramp = linspace(
            start=end_value_pair[1],
            end=specify_broadcastable(edge_pair[1], axis).squeeze(axis),
            steps=width_pair[1] + 1,
        )[:-1]
        # The right ramp is generated end-value-first, so reverse it along `axis`
        # to run from the edge value out to the end value.
        right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]

        # FIXME: This swapaxes is needed because the shapes of the linspaces don't "rotate" with
        # the different dimensions. But this makes the non-active dimensions backwards in the padding.
        padded = set_subtensor(padded[left_slice], swapaxes(left_ramp, 0, axis))
        padded = set_subtensor(padded[right_slice], swapaxes(right_ramp, 0, axis))

    return padded
| 4435 | + |
| 4436 | + |
def pad(x, pad_width, mode="constant", **kwargs):
    """Pad a tensor, following the interface of :func:`numpy.pad`.

    Parameters
    ----------
    x : TensorVariable
        The tensor to pad.
    pad_width : TensorVariable
        Number of values padded to the edges of each axis; broadcastable to
        shape ``(x.ndim, 2)`` as ``(before, after)`` pairs.
    mode : str
        Padding mode. Currently implemented: ``"constant"``, ``"edge"``,
        ``"linear_ramp"``.
    **kwargs
        Mode-specific options (``constant_values`` or ``end_values``).

    Returns
    -------
    TensorVariable
        The padded tensor.

    Raises
    ------
    ValueError
        If ``mode`` is unknown, or a keyword argument is not valid for ``mode``.
    NotImplementedError
        If ``mode`` is a valid numpy mode that is not implemented yet.
    """
    allowed_kwargs = {
        "empty": [],
        "edge": [],
        "wrap": [],
        "constant": ["constant_values"],
        "linear_ramp": ["end_values"],
        "maximum": ["stat_length"],
        "mean": ["stat_length"],
        "median": ["stat_length"],
        "minimum": ["stat_length"],
        "reflect": ["reflect_type"],
        "symmetric": ["reflect_type"],
    }

    # Validate the mode explicitly instead of letting the dict lookup below
    # raise an opaque KeyError.
    if mode not in allowed_kwargs:
        raise ValueError(
            f"Invalid mode '{mode}'. Allowed modes are: {sorted(allowed_kwargs)}"
        )

    # Bug fix: the check must inspect the keyword *names* (kwargs keys), not
    # the argument values — the original rejected every legal call such as
    # pad(x, w, mode="constant", constant_values=0).
    if any(key not in allowed_kwargs[mode] for key in kwargs):
        raise ValueError(
            f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
        )

    if mode == "constant":
        constant_values = kwargs.pop("constant_values", 0)
        return _constant_pad(x, pad_width, constant_values)
    elif mode == "edge":
        return _edge_pad(x, pad_width)
    elif mode == "linear_ramp":
        end_values = kwargs.pop("end_values", 0)
        return _linear_ramp_pad(x, pad_width, end_values)
    else:
        # Listed in allowed_kwargs (valid numpy modes) but no symbolic
        # implementation yet — fail loudly instead of silently returning None.
        raise NotImplementedError(f"Mode '{mode}' is not implemented")
| 4465 | + |
| 4466 | + |
4345 | 4467 | __all__ = [
|
4346 | 4468 | "take_along_axis",
|
4347 | 4469 | "expand_dims",
|
|
0 commit comments