Skip to content

Commit 77d8604

Browse files
Add tests for percentile and quantile
1 parent 15a3718 commit 77d8604

File tree

2 files changed

+85
-22
lines changed

2 files changed

+85
-22
lines changed

pytensor/tensor/math.py

+17-22
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,8 @@
2626
concatenate,
2727
constant,
2828
expand_dims,
29-
full_like,
3029
stack,
3130
switch,
32-
take_along_axis,
3331
)
3432
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
3533
from pytensor.tensor.elemwise import (
@@ -2885,11 +2883,11 @@ def percentile(input, q, axis=None):
28852883
axis: None or int or list of int, optional
28862884
Axis or axes along which the percentiles are computed. The default is to compute the percentile(s) along a flattened version of the array.
28872885
"""
2888-
input = as_tensor_variable(input)
2889-
input_ndim = input.type.ndim
2886+
x = as_tensor_variable(input)
2887+
x_ndim = x.type.ndim
28902888

28912889
if axis is None:
2892-
axis = list(range(input_ndim))
2890+
axis = list(range(x_ndim))
28932891
elif isinstance(axis, (int | np.integer)):
28942892
axis = [axis]
28952893
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
@@ -2898,17 +2896,17 @@ def percentile(input, q, axis=None):
28982896
axis = [int(a) for a in axis]
28992897

29002898
# Compute the shape of the remaining axes
2901-
new_axes_order = [i for i in range(input.ndim) if i not in axis] + list(axis)
2902-
input = input.dimshuffle(new_axes_order)
2903-
input_shape = shape(input)
2904-
remaining_axis_size = input_shape[: input.ndim - len(axis)]
2905-
flattened_axis_size = prod(input_shape[input.ndim - len(axis) :])
2906-
input = input.reshape(concatenate([remaining_axis_size, [flattened_axis_size]]))
2907-
axis = -1
2899+
new_axes_order = [i for i in range(x.ndim) if i not in axis] + list(axis)
2900+
x = x.dimshuffle(new_axes_order)
2901+
input_shape = shape(x)
2902+
remaining_axis_size = input_shape[: x.ndim - len(axis)]
2903+
x = x.reshape((*remaining_axis_size, -1))
29082904

29092905
# Sort the input tensor along the specified axis
2910-
sorted_input = input.sort(axis=axis)
2911-
input_shape = input.shape[axis]
2906+
sorted_input = x.sort(axis=-1)
2907+
slices1 = [slice(None)] * sorted_input.ndim
2908+
slices2 = [slice(None)] * sorted_input.ndim
2909+
input_shape = x.shape[-1]
29122910

29132911
if isinstance(q, (int | float)):
29142912
q = [q]
@@ -2923,18 +2921,15 @@ def percentile(input, q, axis=None):
29232921
k_floor = floor(k).astype("int64")
29242922
k_ceil = ceil(k).astype("int64")
29252923

2926-
indices1 = expand_dims(
2927-
full_like(sorted_input.take(0, axis=axis), k_floor), axis
2928-
)
2929-
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k_ceil), axis)
2930-
2931-
val1 = take_along_axis(sorted_input, indices1, axis=axis)
2932-
val2 = take_along_axis(sorted_input, indices2, axis=axis)
2924+
slices1[-1] = slice(k_floor, k_floor + 1)
2925+
slices2[-1] = slice(k_ceil, k_ceil + 1)
2926+
val1 = sorted_input[tuple(slices1)]
2927+
val2 = sorted_input[tuple(slices2)]
29332928

29342929
d = k - k_floor
29352930
percentile_val = val1 + d * (val2 - val1)
29362931

2937-
result.append(percentile_val.squeeze(axis=axis))
2932+
result.append(percentile_val.squeeze(axis=-1))
29382933

29392934
if len(result) == 1:
29402935
result = result[0]

tests/tensor/test_math.py

+68
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@
102102
neg,
103103
neq,
104104
outer,
105+
percentile,
105106
polygamma,
106107
power,
107108
ptp,
109+
quantile,
108110
rad2deg,
109111
reciprocal,
110112
round_half_away_from_zero,
@@ -3732,3 +3734,69 @@ def test_nan_to_num(nan, posinf, neginf):
37323734
out,
37333735
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
37343736
)
3737+
3738+
3739+
@pytest.mark.parametrize(
3740+
"ndim, axis, q",
3741+
[
3742+
(2, None, 50),
3743+
(2, 1, 33),
3744+
(2, (0, 1), 50),
3745+
(3, (1, 2), 50),
3746+
(4, (1, 3, 0), 25),
3747+
(2, None, [25, 50, 75]),
3748+
(3, (1, 2), [10, 90]),
3749+
(3, 1, 75),
3750+
(3, 0, 50),
3751+
],
3752+
)
3753+
def test_percentile(ndim, axis, q):
3754+
shape = tuple(np.arange(1, ndim + 1))
3755+
data = np.random.rand(*shape)
3756+
x = tensor(shape=np.array(data).shape)
3757+
f = function([x], percentile(x, q, axis=axis))
3758+
result = f(data.astype(x.dtype))
3759+
expected = np.percentile(data.astype(x.dtype), q, axis=axis)
3760+
assert np.allclose(result, expected)
3761+
3762+
3763+
@pytest.mark.parametrize(
3764+
"ndim, axis, q",
3765+
[
3766+
(2, None, 0.5),
3767+
(2, None, [0.25, 0.75]),
3768+
(2, 0, 0.5),
3769+
(2, (0, 1), 0.5),
3770+
(3, None, 0.5),
3771+
(3, None, [0.25, 0.75]),
3772+
(3, 0, 0.5),
3773+
(3, (1, 2), 0.5),
3774+
],
3775+
)
3776+
def test_quantile(ndim, axis, q):
3777+
shape = tuple(np.random.randint(2, 6) for _ in range(ndim))
3778+
data = np.random.rand(*shape)
3779+
3780+
x = tensor(dtype="float64", shape=(None,) * ndim)
3781+
f = function([x], quantile(x, q, axis=axis))
3782+
3783+
result = f(data.astype(x.dtype))
3784+
expected = np.quantile(data.astype(x.dtype), q, axis=axis)
3785+
3786+
assert np.allclose(result, expected)
3787+
3788+
3789+
@pytest.mark.parametrize(
3790+
"ndim, axis, q, is_percentile",
3791+
[
3792+
(2, None, [50, 120], True),
3793+
(2, 1, -0.5, False),
3794+
],
3795+
)
3796+
def test_invalid_percentile_quamtile(ndim, axis, q, is_percentile):
3797+
x = tensor(dtype="float64", shape=(None,) * ndim)
3798+
with pytest.raises(ValueError):
3799+
if is_percentile:
3800+
percentile(x, q, axis)
3801+
else:
3802+
quantile(x, q, axis)

0 commit comments

Comments
 (0)