@@ -2926,99 +2926,85 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
2926
2926
return x
2927
2927
2928
2928
2929
- def quantile (input , q , axis = None ):
2929
+ def quantile (x : TensorLike , q : float | list [ float ] , axis = None ) -> TensorVariable :
2930
2930
"""
2931
- Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation .
2931
+ Computes the q-th quantile along the given axis(es) of a tensor `input`.
2932
2932
2933
2933
Parameters
2934
2934
----------
2935
- input : TensorVariable
2935
+ x : TensorVariable
2936
2936
The input tensor.
2937
- q: float or list of floats
2938
- Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
2939
- axis: None or int or list of int, optional
2940
- Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
2937
+ q: float or list of float
2938
+ The quantile(s) to compute, which must be between 0 and 1 inclusive.
2939
+ 0 corresponds to the minimum, 0.5 to the median, and 1 to the maximum.
2940
+ axis: None or int or (list of int) (see `Sum`)
2941
+ Compute the quantile along this axis of the tensor.
2942
+ None means all axes (like numpy).
2941
2943
"""
2942
- x = as_tensor_variable (input )
2943
- x_ndim = x .type .ndim
2944
2944
2945
+ x = as_tensor_variable (x )
2946
+ q = as_tensor_variable (q )
2947
+ x_ndim = x .type .ndim
2945
2948
if axis is None :
2946
2949
axis = list (range (x_ndim ))
2947
- elif isinstance (axis , (int | np .integer )):
2948
- axis = [axis ]
2949
- elif isinstance (axis , np .ndarray ) and axis .ndim == 0 :
2950
- axis = [int (axis )]
2951
2950
else :
2952
- axis = [int (a ) for a in axis ]
2953
-
2954
- # Compute the shape of the remaining axes
2955
- new_axes_order = [i for i in range (x .ndim ) if i not in axis ] + list (axis )
2956
- x = x .dimshuffle (new_axes_order )
2957
- input_shape = shape (x )
2958
- remaining_axis_size = input_shape [: x .ndim - len (axis )]
2959
- x = x .reshape ((* remaining_axis_size , - 1 ))
2960
-
2961
- # Sort the input tensor along the specified axis
2962
- sorted_input = x .sort (axis = - 1 )
2963
- slices1 = [slice (None )] * sorted_input .ndim
2964
- slices2 = [slice (None )] * sorted_input .ndim
2965
- input_shape = x .shape [- 1 ]
2951
+ axis = list (normalize_axis_tuple (axis , x_ndim ))
2966
2952
2967
- if isinstance ( q , ( int | float )):
2968
- q = [ q ]
2953
+ non_axis = [ i for i in range ( x_ndim ) if i not in axis ]
2954
+ non_axis_shape = [ x . shape [ i ] for i in non_axis ]
2969
2955
2970
- for quantile in q :
2971
- if quantile < 0 or quantile > 1 :
2972
- raise ValueError ("Quantiles must be in the range [0, 1]" )
2956
+ # Put axis at the end and unravel them
2957
+ x_raveled = x .transpose (* non_axis , * axis )
2958
+ if len (axis ) > 1 :
2959
+ x_raveled = x_raveled .reshape ((* non_axis_shape , - 1 ))
2960
+ raveled_size = x_raveled .shape [- 1 ]
2973
2961
2974
- result = []
2975
- for quantile in q :
2976
- k = (quantile ) * (input_shape - 1 )
2977
- k_floor = floor (k ).astype ("int64" )
2978
- k_ceil = ceil (k ).astype ("int64" )
2962
+ # Ensure q is between 0 and 1
2963
+ q = clip (q , 0.0 , 1.0 )
2979
2964
2980
- slices1 [- 1 ] = slice (k_floor , k_floor + 1 )
2981
- slices2 [- 1 ] = slice (k_ceil , k_ceil + 1 )
2982
- val1 = sorted_input [tuple (slices1 )]
2983
- val2 = sorted_input [tuple (slices2 )]
2965
+ # Compute quantile indices
2966
+ k = (q * (raveled_size - 1 )).astype ("int64" )
2967
+ k_float = q * (raveled_size - 1 )
2984
2968
2985
- d = k - k_floor
2986
- quantile_val = val1 + d * ( val2 - val1 )
2969
+ # Sort the input tensor along the specified axis
2970
+ x_sorted = x_raveled . sort ( axis = - 1 )
2987
2971
2988
- result .append (quantile_val .squeeze (axis = - 1 ))
2972
+ # Get the values at index k and k + 1 for linear interpolation
2973
+ k_values = x_sorted [..., k ]
2974
+ kp1_values = x_sorted [..., minimum (k + 1 , raveled_size - 1 )]
2989
2975
2990
- if len (result ) == 1 :
2991
- result = result [0 ]
2992
- else :
2993
- result = stack (result )
2976
+ # Interpolation between the two values if needed
2977
+ frac = k_float - k .astype (k_float .dtype )
2978
+ quantile_value = (1 - frac ) * k_values + frac * kp1_values
2994
2979
2995
- result .name = "quantile"
2996
- return result
2980
+ return quantile_value
2997
2981
2998
2982
2999
- def percentile (input , q , axis = None ):
2983
+ def percentile (x : TensorLike , p : float | list [ float ] , axis = None ) -> TensorVariable :
3000
2984
"""
3001
- Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation .
2985
+ Computes the p-th percentile along the given axis(es) of a tensor `input`.
3002
2986
3003
2987
Parameters
3004
2988
----------
3005
- input : TensorVariable
2989
+ x : TensorVariable
3006
2990
The input tensor.
3007
- q: float or list of floats
3008
- Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
3009
- axis: None or int or list of int, optional
3010
- Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
3011
- """
3012
- if isinstance (q , (int | float )):
3013
- q = [q ]
3014
-
3015
- for percentile in q :
3016
- if percentile < 0 or percentile > 100 :
3017
- raise ValueError ("Percentiles must be in the range [0, 100]" )
2991
+ p: float or list of float
2992
+ The percentile(s) to compute, which must be between 0 and 100 inclusive.
2993
+ 0 corresponds to the minimum, 50 to the median, and 100 to the maximum.
2994
+ axis: None or int or (list of int) (see `Sum`)
2995
+ Compute the percentile along this axis of the tensor.
2996
+ None means all axes (like numpy).
3018
2997
3019
- quantiles = [x / 100 for x in q ]
2998
+ Returns
2999
+ -------
3000
+ TensorVariable
3001
+ The computed percentile values.
3002
+ """
3003
+ # Convert percentiles (0-100) to quantiles (0-1)
3004
+ q = as_tensor_variable (p ) / 100.0
3020
3005
3021
- return quantile (input , quantiles , axis )
3006
+ # Call the quantile function
3007
+ return quantile (x , q , axis = axis )
3022
3008
3023
3009
3024
3010
# NumPy logical aliases
0 commit comments