1- from numpy . lib . arraypad import _get_edges , _slice_at_axis # noqa
2-
3- from pytensor . tensor . basic import (
4- TensorVariable ,
5- as_tensor ,
6- swapaxes ,
7- zeros ,
8- )
9- from pytensor .tensor .extra_ops import linspace , broadcast_to
1+ from collections . abc import Callable
2+ from typing import Literal
3+
4+ from pytensor . tensor import TensorLike
5+ from pytensor . tensor . basic import TensorVariable , as_tensor , zeros
6+ from pytensor . tensor . extra_ops import broadcast_to , linspace
7+ from pytensor . tensor . math import max as pt_max
8+ from pytensor . tensor . math import mean , minimum
9+ from pytensor .tensor .math import min as pt_min
1010from pytensor .tensor .shape import specify_broadcastable
1111from pytensor .tensor .subtensor import set_subtensor
1212
1313
# Mode names accepted by the ``mode`` annotation of ``pad``.
PadMode = Literal[
    "constant",
    "edge",
    "linear_ramp",
    "maximum",
    "minimum",
    "mean",
    "median",
]

# Reduction used for each statistic-based mode. "median" is deliberately
# absent: ``pad`` raises NotImplementedError for it before looking it up here.
stat_funcs = dict(maximum=pt_max, minimum=pt_min, mean=mean)
18+
19+
def _slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
    """
    Construct tuple of slices to slice an array in the given dimension.

    Copied from numpy.lib.arraypad._slice_at_axis
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced".

    Returns
    -------
    sl : tuple of slices
        A tuple of ``axis`` full slices, followed by `sl`, followed by a
        trailing ``Ellipsis`` that leaves every remaining dimension unsliced.

    Examples
    --------
    >>> _slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), Ellipsis)
    """
    # The trailing Ellipsis is not a slice, hence the type-checker override.
    return (slice(None),) * axis + (sl, Ellipsis)  # type: ignore
46+
47+
def _get_edges(
    padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable]
) -> tuple[TensorVariable, TensorVariable]:
    """
    Fetch the outermost valid values of an empty-padded array along one axis.

    Ported from numpy.lib.arraypad._get_edges
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L154

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension along which the edges are taken.
    width_pair : (TensorVariable, TensorVariable)
        Widths of the pad area before and after the valid area in the given
        dimension.

    Returns
    -------
    left_edge, right_edge : TensorVariable
        Length-1 slices (along `axis`) holding the first and the last value of
        the valid area. All other dimensions are taken in full.
    """
    # The valid area spans [first_valid, end_valid) along `axis`.
    first_valid = width_pair[0]
    end_valid = padded.shape[axis] - width_pair[1]

    # One-element-wide slices, so the sliced axis is kept with length 1.
    first_selector = _slice_at_axis(slice(first_valid, first_valid + 1), axis)
    last_selector = _slice_at_axis(slice(end_valid - 1, end_valid), axis)

    return padded[first_selector], padded[last_selector]
83+
84+
1485def _symbolic_pad (
1586 x : TensorVariable , pad_width : TensorVariable
1687) -> tuple [TensorVariable , tuple [slice , ...], TensorVariable ]:
17- pad_width = broadcast_to (pad_width , ( x .ndim , 2 ))
88+ pad_width = broadcast_to (pad_width , as_tensor (( x .ndim , 2 ) ))
1889 new_shape = as_tensor (
1990 [pad_width [i ][0 ] + size + pad_width [i ][1 ] for i , size in enumerate (x .shape )]
2091 )
@@ -26,8 +97,10 @@ def _symbolic_pad(
2697
2798
2899def _get_padding_slices (
29- dim_shape : TensorVariable , width_pair : tuple [TensorVariable , TensorVariable ], axis : int
30- ):
100+ dim_shape : TensorVariable ,
101+ width_pair : tuple [TensorVariable , TensorVariable ],
102+ axis : int ,
103+ ) -> tuple [tuple [slice , ...], tuple [slice , ...]]:
31104 left_slice = _slice_at_axis (slice (None , width_pair [0 ]), axis )
32105 right_slice = _slice_at_axis (slice (dim_shape - width_pair [1 ], None ), axis )
33106
@@ -36,9 +109,9 @@ def _get_padding_slices(
36109
37110def _constant_pad (
38111 x : TensorVariable , pad_width : TensorVariable , constant_values : TensorVariable
39- ):
112+ ) -> TensorVariable :
40113 padded , area_slice , pad_width = _symbolic_pad (x , pad_width )
41- values = broadcast_to (constant_values , ( padded .ndim , 2 ))
114+ values = broadcast_to (constant_values , as_tensor (( padded .ndim , 2 ) ))
42115
43116 for axis in range (padded .ndim ):
44117 width_pair = pad_width [axis ]
@@ -52,7 +125,7 @@ def _constant_pad(
52125 return padded
53126
54127
55- def _edge_pad (x : TensorVariable , pad_width : TensorVariable ):
128+ def _edge_pad (x : TensorVariable , pad_width : TensorVariable ) -> TensorVariable :
56129 padded , area_slice , pad_width = _symbolic_pad (x , pad_width )
57130 for axis in range (padded .ndim ):
58131 width_pair = pad_width [axis ]
@@ -67,42 +140,133 @@ def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
67140 return padded
68141
69142
def _get_stats(
    padded: TensorVariable,
    axis: int,
    width_pair: TensorVariable,
    length_pair: tuple[TensorVariable, TensorVariable] | tuple[None, None],
    stat_func: Callable,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Calculate statistic for the empty-padded array in given dimension.

    Copied from numpy.lib.arraypad._get_stats
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L230

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the statistic is calculated.
    width_pair : (TensorVariable, TensorVariable)
        Pair of widths that mark the pad area on both sides in the given dimension.
    length_pair : 2-element sequence of None or TensorVariable
        Gives the number of values in valid area from each side that is taken into account when calculating the
        statistic. If None the entire valid area in `padded` is considered.
    stat_func : function
        Function to compute statistic. The expected signature is
        ``stat_func(x: TensorVariable, axis: int, keepdims: bool) -> TensorVariable``.

    Returns
    -------
    left_stat, right_stat : TensorVariable
        Calculated statistic for both sides of `padded`. When both entries of
        `length_pair` are None, the same variable is returned for both sides.
    """
    # Calculate indices of the edges of the area with original values
    left_index = width_pair[0]
    right_index = padded.shape[axis] - width_pair[1]
    # as well as its length
    max_length = right_index - left_index

    left_length, right_length = length_pair

    # This has to be decided *before* the lengths are resolved below: once a
    # None length has been replaced by `max_length` it can no longer be told
    # apart from a user-supplied length, and the shortcut would never trigger.
    use_full_area = left_length is None and right_length is None

    # Calculate statistic for the left side (length limited to max_length)
    left_length = (
        minimum(left_length, max_length) if left_length is not None else max_length
    )
    left_slice = _slice_at_axis(slice(left_index, left_index + left_length), axis)
    left_chunk = padded[left_slice]
    left_stat = stat_func(left_chunk, axis=axis, keepdims=True)

    if use_full_area:
        # Both sides reduce over the whole valid area, so the right statistic
        # is identical to the left one and need not be built a second time.
        # We could also return early in the more general case of left_length == right_length, but we don't
        # necessarily know these shapes.
        # TODO: Add rewrite to simplify in this case
        return left_stat, left_stat

    # Calculate statistic for the right side (length limited to max_length)
    right_length = (
        minimum(right_length, max_length) if right_length is not None else max_length
    )
    right_slice = _slice_at_axis(slice(right_index - right_length, right_index), axis)
    right_chunk = padded[right_slice]
    right_stat = stat_func(right_chunk, axis=axis, keepdims=True)

    return left_stat, right_stat
207+
208+
def _stat_pad(
    x: TensorVariable, pad_width: TensorVariable, stat_func, stat_length=None
):
    """
    Fill the pad area of every axis with a statistic of the existing values.

    Parameters
    ----------
    x : TensorVariable
        Array to pad.
    pad_width : TensorVariable
        Pad widths, forwarded to ``_symbolic_pad``.
    stat_func : function
        Reduction passed on to ``_get_stats``; it is called as
        ``stat_func(chunk, axis=axis, keepdims=True)``.
    stat_length : optional
        Number of values on each side used to compute the statistic. It is
        broadcast to shape ``(padded.ndim, 2)``. When None, ``(None, None)`` is
        handed to ``_get_stats`` for every axis.

    Returns
    -------
    padded : TensorVariable
        The padded array with both pad regions of each axis overwritten.
    """
    padded, _area_slice, pad_width = _symbolic_pad(x, pad_width)
    n_axes = padded.ndim

    # One (left, right) length pair per axis.
    stat_length = (
        [[None, None]] * n_axes
        if stat_length is None
        else broadcast_to(stat_length, as_tensor((n_axes, 2)))
    )

    for axis in range(n_axes):
        widths = pad_width[axis]
        axis_size = padded.shape[axis]

        left_stat, right_stat = _get_stats(
            padded, axis, widths, stat_length[axis], stat_func
        )
        left_slice, right_slice = _get_padding_slices(axis_size, widths, axis)

        # Later axes see the values written for earlier axes.
        padded = set_subtensor(padded[left_slice], left_stat)
        padded = set_subtensor(padded[right_slice], right_stat)

    return padded
231+
232+
def _linear_ramp_pad(
    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0
) -> TensorVariable:
    """
    Fill the pad area of every axis with a linear ramp.

    On each side the ramp starts at the corresponding end value at the outer
    border and runs towards the edge value of the existing data, without
    including that edge value (``endpoint=False``).

    Parameters
    ----------
    x : TensorVariable
        Array to pad.
    pad_width : TensorVariable
        Pad widths, forwarded to ``_symbolic_pad``.
    end_values : TensorVariable or int, default 0
        Outermost ramp value(s); broadcast to shape ``(padded.ndim, 2)``.

    Returns
    -------
    padded : TensorVariable
        The padded array with both pad regions of each axis overwritten.
    """
    padded, _area_slice, pad_width = _symbolic_pad(x, pad_width)
    end_values = broadcast_to(as_tensor(end_values), as_tensor((padded.ndim, 2)))

    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        end_value_pair = end_values[axis]
        edge_pair = _get_edges(padded, axis, width_pair)
        left_slice, right_slice = _get_padding_slices(
            padded.shape[axis], width_pair, axis
        )

        # One ramp per side, each laid out along `axis`.
        ramps = [
            linspace(
                start=end_value,
                stop=specify_broadcastable(edge, axis).squeeze(axis),
                num=width,
                endpoint=False,
                dtype=padded.dtype,
                axis=axis,
            )
            for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair)
        ]
        left_ramp, right_ramp = ramps

        # The right-hand ramp has to run from the data edge outwards, so flip it.
        flip_along_axis = _slice_at_axis(slice(None, None, -1), axis)
        right_ramp = right_ramp[flip_along_axis]  # type: ignore

        padded = set_subtensor(padded[left_slice], left_ramp)
        padded = set_subtensor(padded[right_slice], right_ramp)

    return padded
101266
102267
103- def pad (x , pad_width , mode = "constant" , ** kwargs ):
268+ def pad (x : TensorLike , pad_width : TensorLike , mode : PadMode = "constant" , ** kwargs ):
104269 allowed_kwargs = {
105- "empty" : [],
106270 "edge" : [],
107271 "wrap" : [],
108272 "constant" : ["constant_values" ],
@@ -115,16 +279,24 @@ def pad(x, pad_width, mode="constant", **kwargs):
115279 "symmetric" : ["reflect_type" ],
116280 }
117281
118- if any (value not in allowed_kwargs [mode ] for value in kwargs .values ()):
282+ if any (value not in allowed_kwargs [mode ] for value in kwargs .keys ()):
119283 raise ValueError (
120284 f"Invalid keyword arguments for mode '{ mode } ': { kwargs .keys ()} "
121285 )
286+ x = as_tensor (x )
287+ pad_width = as_tensor (pad_width )
122288
123289 if mode == "constant" :
124- constant_values = kwargs .pop ("constant_values" , 0 )
290+ constant_values = as_tensor ( kwargs .pop ("constant_values" , 0 ) )
125291 return _constant_pad (x , pad_width , constant_values )
126292 elif mode == "edge" :
127293 return _edge_pad (x , pad_width )
294+ elif mode in ["maximum" , "minimum" , "mean" , "median" ]:
295+ if mode == "median" :
296+ # TODO: pt.quantile? pt.median?
297+ raise NotImplementedError ("Median padding not implemented" )
298+ stat_func = stat_funcs [mode ]
299+ return _stat_pad (x , pad_width , stat_func , ** kwargs )
128300 elif mode == "linear_ramp" :
129301 end_values = kwargs .pop ("end_values" , 0 )
130302 return _linear_ramp_pad (x , pad_width , end_values )
0 commit comments