Skip to content

Commit cc8d4d9

Browse files
Add tests for percentile and quantile
1 parent 15a3718 commit cc8d4d9

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

tests/tensor/test_math.py

+68
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,11 @@
102102
neg,
103103
neq,
104104
outer,
105+
percentile,
105106
polygamma,
106107
power,
107108
ptp,
109+
quantile,
108110
rad2deg,
109111
reciprocal,
110112
round_half_away_from_zero,
@@ -3732,3 +3734,69 @@ def test_nan_to_num(nan, posinf, neginf):
37323734
out,
37333735
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
37343736
)
3737+
3738+
3739+
@pytest.mark.parametrize(
3740+
"ndim, axis, q",
3741+
[
3742+
(2, None, 50),
3743+
(2, 1, 33),
3744+
(2, (0, 1), 50),
3745+
(3, (1, 2), 50),
3746+
(4, (1, 3, 0), 25),
3747+
(2, None, [25, 50, 75]),
3748+
(3, (1, 2), [10, 90]),
3749+
(3, 1, 75),
3750+
(3, 0, 50),
3751+
],
3752+
)
3753+
def test_percentile(ndim, axis, q):
3754+
shape = tuple(np.arange(1, ndim + 1))
3755+
data = np.random.rand(*shape)
3756+
x = tensor(shape=np.array(data).shape)
3757+
f = function([x], percentile(x, q, axis=axis))
3758+
result = f(data.astype(x.dtype))
3759+
expected = np.percentile(data.astype(x.dtype), q, axis=axis)
3760+
assert np.allclose(result, expected)
3761+
3762+
3763+
@pytest.mark.parametrize(
3764+
"ndim, axis, q",
3765+
[
3766+
(2, None, 0.5),
3767+
(2, None, [0.25, 0.75]),
3768+
(2, 0, 0.5),
3769+
(2, (0, 1), 0.5),
3770+
(3, None, 0.5),
3771+
(3, None, [0.25, 0.75]),
3772+
(3, 0, 0.5),
3773+
(3, (1, 2), 0.5),
3774+
],
3775+
)
3776+
def test_quantile(ndim, axis, q):
3777+
shape = tuple(np.random.randint(2, 6) for _ in range(ndim))
3778+
data = np.random.rand(*shape)
3779+
3780+
x = tensor(dtype="float64", shape=(None,) * ndim)
3781+
f = function([x], quantile(x, q, axis=axis))
3782+
3783+
result = f(data.astype(x.dtype))
3784+
expected = np.quantile(data.astype(x.dtype), q, axis=axis)
3785+
3786+
assert np.allclose(result, expected)
3787+
3788+
3789+
@pytest.mark.parametrize(
3790+
"ndim, axis, q, is_percentile",
3791+
[
3792+
(2, None, [50, 120], True),
3793+
(2, 1, -0.5, False),
3794+
],
3795+
)
3796+
def test_invalid_percentile_quamtile(ndim, axis, q, is_percentile):
3797+
x = tensor(dtype="float64", shape=(None,) * ndim)
3798+
with pytest.raises(ValueError):
3799+
if is_percentile:
3800+
percentile(x, q, axis)
3801+
else:
3802+
quantile(x, q, axis)

0 commit comments

Comments
 (0)