Skip to content

Commit 305eb39

Browse files
authored
ChiSquared now returns a Gamma random variable (#7007)
1 parent cac99d9 commit 305eb39

File tree

4 files changed

+17
-48
lines changed

4 files changed

+17
-48
lines changed

pymc/distributions/continuous.py

+12-19
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
BetaRV,
3939
_gamma,
4040
cauchy,
41-
chisquare,
4241
exponential,
4342
gumbel,
4443
halfcauchy,
@@ -56,7 +55,7 @@
5655
from pytensor.tensor.random.op import RandomVariable
5756
from pytensor.tensor.variable import TensorConstant
5857

59-
from pymc.logprob.abstract import _logcdf_helper, _logprob_helper
58+
from pymc.logprob.abstract import _logprob_helper
6059
from pymc.logprob.basic import icdf
6160

6261
try:
@@ -2374,16 +2373,21 @@ def logcdf(value, alpha, beta):
23742373
)
23752374

23762375

2377-
class ChiSquared(PositiveContinuous):
2376+
class ChiSquared:
23782377
r"""
23792378
:math:`\chi^2` log-likelihood.
23802379
2380+
This is the distribution from the sum of the squares of :math:`\nu` independent standard normal random variables or a special
2381+
case of the gamma distribution with :math:`\alpha = \nu/2` and :math:`\beta = 1/2`.
2382+
23812383
The pdf of this distribution is
23822384
23832385
.. math::
23842386
23852387
f(x \mid \nu) = \frac{x^{(\nu-2)/2}e^{-x/2}}{2^{\nu/2}\Gamma(\nu/2)}
23862388
2389+
Read more about the :math:`\chi^2` distribution at https://en.wikipedia.org/wiki/Chi-squared_distribution
2390+
23872391
.. plot::
23882392
:context: close-figs
23892393
@@ -2413,24 +2417,13 @@ class ChiSquared(PositiveContinuous):
24132417
nu : tensor_like of float
24142418
Degrees of freedom (nu > 0).
24152419
"""
2416-
rv_op = chisquare
2417-
2418-
@classmethod
2419-
def dist(cls, nu, *args, **kwargs):
2420-
nu = pt.as_tensor_variable(floatX(nu))
2421-
return super().dist([nu], *args, **kwargs)
24222420

2423-
def moment(rv, size, nu):
2424-
moment = nu
2425-
if not rv_size_is_none(size):
2426-
moment = pt.full(size, moment)
2427-
return moment
2421+
def __new__(cls, name, nu, **kwargs):
2422+
return Gamma(name, alpha=nu / 2, beta=1 / 2, **kwargs)
24282423

2429-
def logp(value, nu):
2430-
return _logprob_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)
2431-
2432-
def logcdf(value, nu):
2433-
return _logcdf_helper(Gamma.dist(alpha=nu / 2, beta=0.5), value)
2424+
@classmethod
2425+
def dist(cls, nu, **kwargs):
2426+
return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)
24342427

24352428

24362429
# TODO: Remove this once logp for multiplication is working!

tests/distributions/test_continuous.py

-26
Original file line numberDiff line numberDiff line change
@@ -1132,19 +1132,6 @@ def test_beta_moment(self, alpha, beta, size, expected):
11321132
pm.Beta("x", alpha=alpha, beta=beta, size=size)
11331133
assert_moment_is_expected(model, expected)
11341134

1135-
@pytest.mark.parametrize(
1136-
"nu, size, expected",
1137-
[
1138-
(1, None, 1),
1139-
(1, 5, np.full(5, 1)),
1140-
(np.arange(1, 6), None, np.arange(1, 6)),
1141-
],
1142-
)
1143-
def test_chisquared_moment(self, nu, size, expected):
1144-
with pm.Model() as model:
1145-
pm.ChiSquared("x", nu=nu, size=size)
1146-
assert_moment_is_expected(model, expected)
1147-
11481135
@pytest.mark.parametrize(
11491136
"lam, size, expected",
11501137
[
@@ -2243,19 +2230,6 @@ class TestInverseGammaMuSigma(BaseTestDistributionRandom):
22432230
checks_to_run = ["check_pymc_params_match_rv_op"]
22442231

22452232

2246-
class TestChiSquared(BaseTestDistributionRandom):
2247-
pymc_dist = pm.ChiSquared
2248-
pymc_dist_params = {"nu": 2.0}
2249-
expected_rv_op_params = {"nu": 2.0}
2250-
reference_dist_params = {"df": 2.0}
2251-
reference_dist = seeded_numpy_distribution_builder("chisquare")
2252-
checks_to_run = [
2253-
"check_pymc_params_match_rv_op",
2254-
"check_pymc_draws_match_reference",
2255-
"check_rv_size",
2256-
]
2257-
2258-
22592233
class TestLogistic(BaseTestDistributionRandom):
22602234
pymc_dist = pm.Logistic
22612235
pymc_dist_params = {"mu": 1.0, "s": 2.0}

tests/logprob/test_transform_value.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from pytensor.graph import FunctionGraph
2727
from pytensor.graph.basic import equal_computations
2828

29+
import pymc as pm
30+
2931
from pymc.distributions.transforms import _default_transform, log, logodds
3032
from pymc.logprob import conditional_logp
3133
from pymc.logprob.abstract import MeasurableVariable, _logprob
@@ -154,7 +156,7 @@ def test_original_values_output_dict():
154156
(),
155157
),
156158
(
157-
pt.random.chisquare,
159+
pm.ChiSquared.dist,
158160
(1.5,),
159161
lambda df: sp.stats.chi2(df),
160162
(),

tests/logprob/test_transforms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
from pytensor.graph.basic import equal_computations
4545

46-
from pymc.distributions.continuous import Cauchy
46+
from pymc.distributions.continuous import Cauchy, ChiSquared
4747
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
4848
from pymc.logprob.transforms import (
4949
ArccoshTransform,
@@ -431,7 +431,7 @@ def test_sqr_transform(self):
431431

432432
def test_sqrt_transform(self):
433433
# The sqrt of a chisquare with n df is a chi distribution with n df
434-
x_rv = pt.sqrt(pt.random.chisquare(df=3, size=(4,)))
434+
x_rv = pt.sqrt(ChiSquared.dist(nu=3, size=(4,)))
435435
x_rv.name = "x"
436436

437437
x_vv = x_rv.clone()

0 commit comments

Comments
 (0)