1
- from numpy . lib . arraypad import _get_edges , _slice_at_axis # noqa
2
-
3
- from pytensor . tensor . basic import (
4
- TensorVariable ,
5
- as_tensor ,
6
- swapaxes ,
7
- zeros ,
8
- )
9
- from pytensor .tensor .extra_ops import linspace , broadcast_to
1
+ from collections . abc import Callable
2
+ from typing import Literal
3
+
4
+ from pytensor . tensor import TensorLike
5
+ from pytensor . tensor . basic import TensorVariable , as_tensor , zeros
6
+ from pytensor . tensor . extra_ops import broadcast_to , linspace
7
+ from pytensor . tensor . math import max as pt_max
8
+ from pytensor . tensor . math import mean , minimum
9
+ from pytensor .tensor .math import min as pt_min
10
10
from pytensor .tensor .shape import specify_broadcastable
11
11
from pytensor .tensor .subtensor import set_subtensor
12
12
13
13
14
# Mode names used to annotate the `mode` argument of `pad`.
PadMode = Literal[
    "constant",
    "edge",
    "linear_ramp",
    "maximum",
    "minimum",
    "mean",
    "median",
]

# Reduction used by each statistic-based padding mode. "median" has no entry:
# `pad` raises NotImplementedError for it before consulting this table.
stat_funcs = {
    "maximum": pt_max,
    "minimum": pt_min,
    "mean": mean,
}
18
+
19
+
20
+ def _slice_at_axis (sl : slice , axis : int ) -> tuple [slice , ...]:
21
+ """
22
+ Construct tuple of slices to slice an array in the given dimension.
23
+
24
+ Copied from numpy.lib.arraypad._slice_at_axis
25
+ https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L33
26
+
27
+ Parameters
28
+ ----------
29
+ sl : slice
30
+ The slice for the given dimension.
31
+ axis : int
32
+ The axis to which `sl` is applied. All other dimensions are left
33
+ "unsliced".
34
+
35
+ Returns
36
+ -------
37
+ sl : tuple of slices
38
+ A tuple with slices matching `shape` in length.
39
+
40
+ Examples
41
+ --------
42
+ >>> _slice_at_axis(slice(None, 3, -1), 1)
43
+ (slice(None, None, None), slice(None, 3, -1), (...,))
44
+ """
45
+ return (slice (None ),) * axis + (sl ,) + (...,) # type: ignore
46
+
47
+
48
def _get_edges(
    padded: TensorVariable, axis: int, width_pair: tuple[TensorVariable, TensorVariable]
) -> tuple[TensorVariable, TensorVariable]:
    """
    Return the first and last slab of the valid area of `padded` along `axis`.

    Copied from numpy.lib.arraypad._get_edges
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L154

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the edges are considered.
    width_pair : (TensorVariable, TensorVariable)
        Pad widths (before, after) along `axis`; they delimit the valid area.

    Returns
    -------
    left_edge, right_edge : TensorVariable
        Length-1 slices of `padded` along `axis`, taken at the first and the
        last index of the valid area. All other dimensions are untouched.
    """
    pad_before, pad_after = width_pair[0], width_pair[1]

    # The valid area starts right after the leading pad ...
    first = pad_before
    left_edge = padded[_slice_at_axis(slice(first, first + 1), axis)]

    # ... and stops where the trailing pad begins (exclusive bound).
    stop = padded.shape[axis] - pad_after
    right_edge = padded[_slice_at_axis(slice(stop - 1, stop), axis)]

    return left_edge, right_edge
83
+
84
+
14
85
def _symbolic_pad (
15
86
x : TensorVariable , pad_width : TensorVariable
16
87
) -> tuple [TensorVariable , tuple [slice , ...], TensorVariable ]:
17
- pad_width = broadcast_to (pad_width , ( x .ndim , 2 ))
88
+ pad_width = broadcast_to (pad_width , as_tensor (( x .ndim , 2 ) ))
18
89
new_shape = as_tensor (
19
90
[pad_width [i ][0 ] + size + pad_width [i ][1 ] for i , size in enumerate (x .shape )]
20
91
)
@@ -26,8 +97,10 @@ def _symbolic_pad(
26
97
27
98
28
99
def _get_padding_slices (
29
- dim_shape : TensorVariable , width_pair : tuple [TensorVariable , TensorVariable ], axis : int
30
- ):
100
+ dim_shape : TensorVariable ,
101
+ width_pair : tuple [TensorVariable , TensorVariable ],
102
+ axis : int ,
103
+ ) -> tuple [tuple [slice , ...], tuple [slice , ...]]:
31
104
left_slice = _slice_at_axis (slice (None , width_pair [0 ]), axis )
32
105
right_slice = _slice_at_axis (slice (dim_shape - width_pair [1 ], None ), axis )
33
106
@@ -36,9 +109,9 @@ def _get_padding_slices(
36
109
37
110
def _constant_pad (
38
111
x : TensorVariable , pad_width : TensorVariable , constant_values : TensorVariable
39
- ):
112
+ ) -> TensorVariable :
40
113
padded , area_slice , pad_width = _symbolic_pad (x , pad_width )
41
- values = broadcast_to (constant_values , ( padded .ndim , 2 ))
114
+ values = broadcast_to (constant_values , as_tensor (( padded .ndim , 2 ) ))
42
115
43
116
for axis in range (padded .ndim ):
44
117
width_pair = pad_width [axis ]
@@ -52,7 +125,7 @@ def _constant_pad(
52
125
return padded
53
126
54
127
55
- def _edge_pad (x : TensorVariable , pad_width : TensorVariable ):
128
+ def _edge_pad (x : TensorVariable , pad_width : TensorVariable ) -> TensorVariable :
56
129
padded , area_slice , pad_width = _symbolic_pad (x , pad_width )
57
130
for axis in range (padded .ndim ):
58
131
width_pair = pad_width [axis ]
@@ -67,42 +140,133 @@ def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
67
140
return padded
68
141
69
142
143
def _get_stats(
    padded: TensorVariable,
    axis: int,
    width_pair: TensorVariable,
    length_pair: tuple[TensorVariable, TensorVariable] | tuple[None, None],
    stat_func: Callable,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Calculate statistic for the empty-padded array in given dimension.

    Copied from numpy.lib.arraypad._get_stats
    https://github.com/numpy/numpy/blob/300096d384046eee479b0c7a70f79e308da52bff/numpy/lib/_arraypad_impl.py#L230

    Parameters
    ----------
    padded : TensorVariable
        Empty-padded array.
    axis : int
        Dimension in which the statistic is calculated.
    width_pair : (TensorVariable, TensorVariable)
        Pair of widths that mark the pad area on both sides in the given dimension.
    length_pair : 2-element sequence of None or TensorVariable
        Gives the number of values in valid area from each side that is taken into account when calculating the
        statistic. If None the entire valid area in `padded` is considered.
    stat_func : function
        Function to compute statistic. The expected signature is
        ``stat_func(x: TensorVariable, axis: int, keepdims: bool) -> TensorVariable``.

    Returns
    -------
    left_stat, right_stat : TensorVariable
        Calculated statistic for both sides of `padded`. When both entries of
        `length_pair` are None the same variable is returned twice.
    """
    # Calculate indices of the edges of the area with original values
    left_index = width_pair[0]
    right_index = padded.shape[axis] - width_pair[1]
    # as well as its length
    max_length = right_index - left_index

    left_length, right_length = length_pair

    # This must be decided *before* the lengths are replaced below: afterwards
    # neither of them can be None any more and the shortcut would never fire.
    use_whole_area = left_length is None and right_length is None

    # Calculate statistic for the left side, limiting the length to max_length
    left_length = (
        minimum(left_length, max_length) if left_length is not None else max_length
    )
    left_slice = _slice_at_axis(slice(left_index, left_index + left_length), axis)
    left_chunk = padded[left_slice]
    left_stat = stat_func(left_chunk, axis=axis, keepdims=True)
    if use_whole_area:
        # Both sides reduce over the entire valid area, so the right statistic
        # is identical to the left one.
        # We could also return early in the more general case of left_length == right_length, but we don't necessarily
        # know these shapes.
        # TODO: Add rewrite to simplify in this case
        return left_stat, left_stat

    # Calculate statistic for the right side
    right_length = (
        minimum(right_length, max_length) if right_length is not None else max_length
    )
    right_slice = _slice_at_axis(slice(right_index - right_length, right_index), axis)
    right_chunk = padded[right_slice]
    right_stat = stat_func(right_chunk, axis=axis, keepdims=True)

    return left_stat, right_stat
207
+
208
+
209
def _stat_pad(
    x: TensorVariable,
    pad_width: TensorVariable,
    stat_func: Callable,
    stat_length: TensorLike | None = None,
) -> TensorVariable:
    """
    Pad `x` with a statistic computed from its own values along each axis.

    Parameters
    ----------
    x : TensorVariable
        Tensor to pad.
    pad_width : TensorVariable
        Pad widths, handed to `_symbolic_pad`; indexed per axis as a
        (before, after) pair.
    stat_func : function
        Reduction forwarded to `_get_stats`, called as
        ``stat_func(chunk, axis=axis, keepdims=True)``.
    stat_length : tensor-like, optional
        Number of values at each end of the valid area used for the statistic.
        Broadcast to shape ``(ndim, 2)``. If None, the whole valid area is used
        on both sides of every axis.

    Returns
    -------
    padded : TensorVariable
        The padded tensor with both pad regions of every axis filled in.
    """
    padded, _, pad_width = _symbolic_pad(x, pad_width)

    if stat_length is None:
        stat_length = [(None, None)] * padded.ndim
    else:
        # Normalise first, mirroring how `_linear_ramp_pad` treats `end_values`.
        stat_length = broadcast_to(
            as_tensor(stat_length), as_tensor((padded.ndim, 2))
        )

    # Axes are processed in order, so the statistic for a later axis is taken
    # over a tensor whose earlier axes are already padded.
    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        length_pair = stat_length[axis]
        dim_shape = padded.shape[axis]

        left_stat, right_stat = _get_stats(
            padded, axis, width_pair, length_pair, stat_func
        )
        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
        padded = set_subtensor(padded[left_slice], left_stat)
        padded = set_subtensor(padded[right_slice], right_stat)

    return padded
231
+
232
+
70
233
def _linear_ramp_pad(
    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0
) -> TensorVariable:
    """
    Pad `x` with values that ramp linearly between `end_values` and its edges.

    Parameters
    ----------
    x : TensorVariable
        Tensor to pad.
    pad_width : TensorVariable
        Pad widths, handed to `_symbolic_pad`; indexed per axis as a
        (before, after) pair.
    end_values : TensorVariable or int, default 0
        Value at the outer end of each ramp. Broadcast to shape ``(ndim, 2)``,
        where column 0 is used for the leading pad and column 1 for the
        trailing pad of each axis.

    Returns
    -------
    padded : TensorVariable
        The padded tensor with both pad regions of every axis filled in.
    """
    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
    target_shape = as_tensor((padded.ndim, 2))
    end_values = broadcast_to(as_tensor(end_values), target_shape)

    # Slicing with this along `axis` flips a ramp end-to-end.
    backwards = slice(None, None, -1)

    for axis in range(padded.ndim):
        width_pair = pad_width[axis]
        end_value_pair = end_values[axis]
        edge_pair = _get_edges(padded, axis, width_pair)
        left_slice, right_slice = _get_padding_slices(
            padded.shape[axis], width_pair, axis
        )

        # One ramp per side. Each starts at the requested end value and heads
        # towards the edge of the valid area; the edge itself is excluded
        # (endpoint=False) because it is already present in `padded`.
        ramps = [
            linspace(
                start=end_value,
                stop=specify_broadcastable(edge, axis).squeeze(axis),
                num=width,
                endpoint=False,
                dtype=padded.dtype,
                axis=axis,
            )
            for end_value, edge, width in zip(end_value_pair, edge_pair, width_pair)
        ]
        left_ramp, right_ramp = ramps

        # The trailing ramp must run from the edge outwards, so reverse it.
        right_ramp = right_ramp[_slice_at_axis(backwards, axis)]  # type: ignore

        padded = set_subtensor(padded[left_slice], left_ramp)
        padded = set_subtensor(padded[right_slice], right_ramp)

    return padded
101
266
102
267
103
- def pad (x , pad_width , mode = "constant" , ** kwargs ):
268
+ def pad (x : TensorLike , pad_width : TensorLike , mode : PadMode = "constant" , ** kwargs ):
104
269
allowed_kwargs = {
105
- "empty" : [],
106
270
"edge" : [],
107
271
"wrap" : [],
108
272
"constant" : ["constant_values" ],
@@ -115,16 +279,24 @@ def pad(x, pad_width, mode="constant", **kwargs):
115
279
"symmetric" : ["reflect_type" ],
116
280
}
117
281
118
- if any (value not in allowed_kwargs [mode ] for value in kwargs .values ()):
282
+ if any (value not in allowed_kwargs [mode ] for value in kwargs .keys ()):
119
283
raise ValueError (
120
284
f"Invalid keyword arguments for mode '{ mode } ': { kwargs .keys ()} "
121
285
)
286
+ x = as_tensor (x )
287
+ pad_width = as_tensor (pad_width )
122
288
123
289
if mode == "constant" :
124
- constant_values = kwargs .pop ("constant_values" , 0 )
290
+ constant_values = as_tensor ( kwargs .pop ("constant_values" , 0 ) )
125
291
return _constant_pad (x , pad_width , constant_values )
126
292
elif mode == "edge" :
127
293
return _edge_pad (x , pad_width )
294
+ elif mode in ["maximum" , "minimum" , "mean" , "median" ]:
295
+ if mode == "median" :
296
+ # TODO: pt.quantile? pt.median?
297
+ raise NotImplementedError ("Median padding not implemented" )
298
+ stat_func = stat_funcs [mode ]
299
+ return _stat_pad (x , pad_width , stat_func , ** kwargs )
128
300
elif mode == "linear_ramp" :
129
301
end_values = kwargs .pop ("end_values" , 0 )
130
302
return _linear_ramp_pad (x , pad_width , end_values )
0 commit comments