Skip to content

Commit dcdeef7

Browse files
Vectorised percentile/quantile
1 parent 17e6aa3 commit dcdeef7

File tree

2 files changed

+53
-66
lines changed

2 files changed

+53
-66
lines changed

pytensor/tensor/math.py

+52-66
Original file line numberDiff line numberDiff line change
@@ -2926,99 +2926,85 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
29262926
return x
29272927

29282928

2929-
def quantile(input, q, axis=None):
2929+
def quantile(x: TensorLike, q: float | list[float], axis=None) -> TensorVariable:
29302930
"""
2931-
Computes the quantile along the given axis(es) of a tensor `input` using linear interpolation.
2931+
Computes the q-th quantile along the given axis(es) of a tensor `input`.
29322932
29332933
Parameters
29342934
----------
2935-
input: TensorVariable
2935+
x: TensorVariable
29362936
The input tensor.
2937-
q: float or list of floats
2938-
Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive.
2939-
axis: None or int or list of int, optional
2940-
Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array.
2937+
q: float or list of float
2938+
The quantile(s) to compute, which must be between 0 and 1 inclusive.
2939+
0 corresponds to the minimum, 0.5 to the median, and 1 to the maximum.
2940+
axis: None or int or (list of int) (see `Sum`)
2941+
Compute the quantile along this axis of the tensor.
2942+
None means all axes (like numpy).
29412943
"""
2942-
x = as_tensor_variable(input)
2943-
x_ndim = x.type.ndim
29442944

2945+
x = as_tensor_variable(x)
2946+
q = as_tensor_variable(q)
2947+
x_ndim = x.type.ndim
29452948
if axis is None:
29462949
axis = list(range(x_ndim))
2947-
elif isinstance(axis, (int | np.integer)):
2948-
axis = [axis]
2949-
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
2950-
axis = [int(axis)]
29512950
else:
2952-
axis = [int(a) for a in axis]
2953-
2954-
# Compute the shape of the remaining axes
2955-
new_axes_order = [i for i in range(x.ndim) if i not in axis] + list(axis)
2956-
x = x.dimshuffle(new_axes_order)
2957-
input_shape = shape(x)
2958-
remaining_axis_size = input_shape[: x.ndim - len(axis)]
2959-
x = x.reshape((*remaining_axis_size, -1))
2960-
2961-
# Sort the input tensor along the specified axis
2962-
sorted_input = x.sort(axis=-1)
2963-
slices1 = [slice(None)] * sorted_input.ndim
2964-
slices2 = [slice(None)] * sorted_input.ndim
2965-
input_shape = x.shape[-1]
2951+
axis = list(normalize_axis_tuple(axis, x_ndim))
29662952

2967-
if isinstance(q, (int | float)):
2968-
q = [q]
2953+
non_axis = [i for i in range(x_ndim) if i not in axis]
2954+
non_axis_shape = [x.shape[i] for i in non_axis]
29692955

2970-
for quantile in q:
2971-
if quantile < 0 or quantile > 1:
2972-
raise ValueError("Quantiles must be in the range [0, 1]")
2956+
# Put axis at the end and unravel them
2957+
x_raveled = x.transpose(*non_axis, *axis)
2958+
if len(axis) > 1:
2959+
x_raveled = x_raveled.reshape((*non_axis_shape, -1))
2960+
raveled_size = x_raveled.shape[-1]
29732961

2974-
result = []
2975-
for quantile in q:
2976-
k = (quantile) * (input_shape - 1)
2977-
k_floor = floor(k).astype("int64")
2978-
k_ceil = ceil(k).astype("int64")
2962+
# Ensure q is between 0 and 1
2963+
q = clip(q, 0.0, 1.0)
29792964

2980-
slices1[-1] = slice(k_floor, k_floor + 1)
2981-
slices2[-1] = slice(k_ceil, k_ceil + 1)
2982-
val1 = sorted_input[tuple(slices1)]
2983-
val2 = sorted_input[tuple(slices2)]
2965+
# Compute quantile indices
2966+
k = (q * (raveled_size - 1)).astype("int64")
2967+
k_float = q * (raveled_size - 1)
29842968

2985-
d = k - k_floor
2986-
quantile_val = val1 + d * (val2 - val1)
2969+
# Sort the input tensor along the specified axis
2970+
x_sorted = x_raveled.sort(axis=-1)
29872971

2988-
result.append(quantile_val.squeeze(axis=-1))
2972+
# Get the values at index k and k + 1 for linear interpolation
2973+
k_values = x_sorted[..., k]
2974+
kp1_values = x_sorted[..., minimum(k + 1, raveled_size - 1)]
29892975

2990-
if len(result) == 1:
2991-
result = result[0]
2992-
else:
2993-
result = stack(result)
2976+
# Interpolation between the two values if needed
2977+
frac = k_float - k.astype(k_float.dtype)
2978+
quantile_value = (1 - frac) * k_values + frac * kp1_values
29942979

2995-
result.name = "quantile"
2996-
return result
2980+
return quantile_value
29972981

29982982

2999-
def percentile(input, q, axis=None):
2983+
def percentile(x: TensorLike, p: float | list[float], axis=None) -> TensorVariable:
30002984
"""
3001-
Computes the percentile along the given axis(es) of a tensor `input` using linear interpolation.
2985+
Computes the p-th percentile along the given axis(es) of a tensor `input`.
30022986
30032987
Parameters
30042988
----------
3005-
input: TensorVariable
2989+
x: TensorVariable
30062990
The input tensor.
3007-
q: float or list of floats
3008-
Percentile or sequence of percentiles to compute, which must be between 0 and 100 inclusive.
3009-
axis: None or int or list of int, optional
3010-
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
3011-
"""
3012-
if isinstance(q, (int | float)):
3013-
q = [q]
3014-
3015-
for percentile in q:
3016-
if percentile < 0 or percentile > 100:
3017-
raise ValueError("Percentiles must be in the range [0, 100]")
2991+
p: float or list of float
2992+
The percentile(s) to compute, which must be between 0 and 100 inclusive.
2993+
0 corresponds to the minimum, 50 to the median, and 100 to the maximum.
2994+
axis: None or int or (list of int) (see `Sum`)
2995+
Compute the percentile along this axis of the tensor.
2996+
None means all axes (like numpy).
30182997
3019-
quantiles = [x / 100 for x in q]
2998+
Returns
2999+
-------
3000+
TensorVariable
3001+
The computed percentile values.
3002+
"""
3003+
# Convert percentiles (0-100) to quantiles (0-1)
3004+
q = as_tensor_variable(p) / 100.0
30203005

3021-
return quantile(input, quantiles, axis)
3006+
# Call the quantile function
3007+
return quantile(x, q, axis=axis)
30223008

30233009

30243010
# NumPy logical aliases

tests/tensor/test_math.py

+1
Original file line numberDiff line numberDiff line change
@@ -3805,6 +3805,7 @@ def test_percentile(ndim, axis, q):
38053805
(3, None, [0.25, 0.75]),
38063806
(3, 0, 0.5),
38073807
(3, (1, 2), 0.5),
3808+
(3, (1, 2), [0.1, 0.9]),
38083809
],
38093810
)
38103811
def test_quantile(ndim, axis, q):

0 commit comments

Comments
 (0)