26
26
concatenate ,
27
27
constant ,
28
28
expand_dims ,
29
- full_like ,
30
29
stack ,
31
30
switch ,
32
- take_along_axis ,
33
31
)
34
32
from pytensor .tensor .blockwise import Blockwise , vectorize_node_fallback
35
33
from pytensor .tensor .elemwise import (
@@ -2885,11 +2883,11 @@ def percentile(input, q, axis=None):
2885
2883
axis: None or int or list of int, optional
2886
2884
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
2887
2885
"""
2888
- input = as_tensor_variable (input )
2889
- input_ndim = input .type .ndim
2886
+ x = as_tensor_variable (input )
2887
+ x_ndim = x .type .ndim
2890
2888
2891
2889
if axis is None :
2892
- axis = list (range (input_ndim ))
2890
+ axis = list (range (x_ndim ))
2893
2891
elif isinstance (axis , (int | np .integer )):
2894
2892
axis = [axis ]
2895
2893
elif isinstance (axis , np .ndarray ) and axis .ndim == 0 :
@@ -2898,17 +2896,17 @@ def percentile(input, q, axis=None):
2898
2896
axis = [int (a ) for a in axis ]
2899
2897
2900
2898
# Compute the shape of the remaining axes
2901
- new_axes_order = [i for i in range (input .ndim ) if i not in axis ] + list (axis )
2902
- input = input .dimshuffle (new_axes_order )
2903
- input_shape = shape (input )
2904
- remaining_axis_size = input_shape [: input .ndim - len (axis )]
2905
- flattened_axis_size = prod (input_shape [input .ndim - len (axis ) :])
2906
- input = input .reshape (concatenate ([remaining_axis_size , [flattened_axis_size ]]))
2907
- axis = - 1
2899
+ new_axes_order = [i for i in range (x .ndim ) if i not in axis ] + list (axis )
2900
+ x = x .dimshuffle (new_axes_order )
2901
+ input_shape = shape (x )
2902
+ remaining_axis_size = input_shape [: x .ndim - len (axis )]
2903
+ x = x .reshape ((* remaining_axis_size , - 1 ))
2908
2904
2909
2905
# Sort the input tensor along the specified axis
2910
- sorted_input = input .sort (axis = axis )
2911
- input_shape = input .shape [axis ]
2906
+ sorted_input = x .sort (axis = - 1 )
2907
+ slices1 = [slice (None )] * sorted_input .ndim
2908
+ slices2 = [slice (None )] * sorted_input .ndim
2909
+ input_shape = x .shape [- 1 ]
2912
2910
2913
2911
if isinstance (q , (int | float )):
2914
2912
q = [q ]
@@ -2923,18 +2921,15 @@ def percentile(input, q, axis=None):
2923
2921
k_floor = floor (k ).astype ("int64" )
2924
2922
k_ceil = ceil (k ).astype ("int64" )
2925
2923
2926
- indices1 = expand_dims (
2927
- full_like (sorted_input .take (0 , axis = axis ), k_floor ), axis
2928
- )
2929
- indices2 = expand_dims (full_like (sorted_input .take (0 , axis = axis ), k_ceil ), axis )
2930
-
2931
- val1 = take_along_axis (sorted_input , indices1 , axis = axis )
2932
- val2 = take_along_axis (sorted_input , indices2 , axis = axis )
2924
+ slices1 [- 1 ] = slice (k_floor , k_floor + 1 )
2925
+ slices2 [- 1 ] = slice (k_ceil , k_ceil + 1 )
2926
+ val1 = sorted_input [tuple (slices1 )]
2927
+ val2 = sorted_input [tuple (slices2 )]
2933
2928
2934
2929
d = k - k_floor
2935
2930
percentile_val = val1 + d * (val2 - val1 )
2936
2931
2937
- result .append (percentile_val .squeeze (axis = axis ))
2932
+ result .append (percentile_val .squeeze (axis = - 1 ))
2938
2933
2939
2934
if len (result ) == 1 :
2940
2935
result = result [0 ]
0 commit comments