Skip to content

Commit 9e0e584

Browse files
larryshamalamatwiecki
larryshamalama
authored andcommitted
Refactoring the ChiSquared distribution (#4695)
* Refactoring ChiSquared distribution * Refactoring ChiSquared (minor edit) * Refactoring Chisquared (another one-line change) * Trying to rebase/merge my branch with updated upstream v4 * Using aesara chisquare op (r.f. PR #414) and renamed ChiSquared to ChiSquare * Added logpdf & logcdf to the ChiSquare class * Corrected function name * Updating branch * Refactoring ChiSquared: bug fixed, tests work locally * Minor fix: removed float32 specification * ☀️ underflow to -inf seems normal in float32 * Minor fix in documentation
1 parent 4de95c3 commit 9e0e584

File tree

3 files changed

+68
-14
lines changed

3 files changed

+68
-14
lines changed

pymc3/distributions/continuous.py

+41-5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
BetaRV,
2929
WeibullRV,
3030
cauchy,
31+
chisquare,
3132
exponential,
3233
gamma,
3334
gumbel,
@@ -2562,7 +2563,7 @@ def logcdf(value, alpha, beta):
25622563
)
25632564

25642565

2565-
class ChiSquared(Gamma):
2566+
class ChiSquared(PositiveContinuous):
25662567
r"""
25672568
:math:`\chi^2` log-likelihood.
25682569
@@ -2597,13 +2598,48 @@ class ChiSquared(Gamma):
25972598
25982599
Parameters
25992600
----------
2600-
nu: int
2601+
nu: float
26012602
Degrees of freedom (nu > 0).
26022603
"""
2604+
rv_op = chisquare
26032605

2604-
def __init__(self, nu, *args, **kwargs):
2605-
self.nu = nu = at.as_tensor_variable(floatX(nu))
2606-
super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)
2606+
@classmethod
2607+
def dist(cls, nu, *args, **kwargs):
2608+
nu = at.as_tensor_variable(floatX(nu))
2609+
return super().dist([nu], *args, **kwargs)
2610+
2611+
def logp(value, nu):
2612+
"""
2613+
Calculate log-probability of ChiSquared distribution at specified value.
2614+
2615+
Parameters
2616+
----------
2617+
value: numeric
2618+
Value(s) for which log-probability is calculated. If the log probabilities for multiple
2619+
values are desired the values must be provided in a numpy array or Aesara tensor
2620+
2621+
Returns
2622+
-------
2623+
TensorVariable
2624+
"""
2625+
return Gamma.logp(value, nu / 2, 2)
2626+
2627+
def logcdf(value, nu):
2628+
"""
2629+
Compute the log of the cumulative distribution function for ChiSquared distribution
2630+
at the specified value.
2631+
2632+
Parameters
2633+
----------
2634+
value: numeric or np.ndarray or `TensorVariable`
2635+
Value(s) for which log CDF is calculated. If the log CDF for
2636+
multiple values are desired the values must be provided in a numpy
2637+
array or `TensorVariable`.
2638+
Returns
2639+
-------
2640+
TensorVariable
2641+
"""
2642+
return Gamma.logcdf(value, nu / 2, 2)
26072643

26082644

26092645
# TODO: Remove this once logpt for multiplication is working!

pymc3/tests/test_distributions.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1048,15 +1048,26 @@ def test_half_normal(self):
10481048
lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
10491049
)
10501050

1051-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1052-
def test_chi_squared(self):
1051+
def test_chisquared_logp(self):
10531052
self.check_logp(
10541053
ChiSquared,
10551054
Rplus,
1056-
{"nu": Rplusdunif},
1055+
{"nu": Rplus},
10571056
lambda value, nu: sp.chi2.logpdf(value, df=nu),
10581057
)
10591058

1059+
@pytest.mark.xfail(
1060+
condition=(aesara.config.floatX == "float32"),
1061+
reason="Fails on float32 due to numerical issues",
1062+
)
1063+
def test_chisquared_logcdf(self):
1064+
self.check_logcdf(
1065+
ChiSquared,
1066+
Rplus,
1067+
{"nu": Rplus},
1068+
lambda value, nu: sp.chi2.logcdf(value, df=nu),
1069+
)
1070+
10601071
@pytest.mark.xfail(reason="Distribution not refactored yet")
10611072
def test_wald_logp(self):
10621073
self.check_logp(

pymc3/tests/test_distributions_random.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,6 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
276276
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}
277277

278278

279-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
280-
class TestChiSquared(BaseTestCases.BaseTestCase):
281-
distribution = pm.ChiSquared
282-
params = {"nu": 2.0}
283-
284-
285279
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
286280
class TestExGaussian(BaseTestCases.BaseTestCase):
287281
distribution = pm.ExGaussian
@@ -782,6 +776,19 @@ class TestInverseGammaMuSigma(BaseTestDistribution):
782776
tests_to_run = ["check_pymc_params_match_rv_op"]
783777

784778

779+
class TestChiSquared(BaseTestDistribution):
780+
pymc_dist = pm.ChiSquared
781+
pymc_dist_params = {"nu": 2.0}
782+
expected_rv_op_params = {"nu": 2.0}
783+
reference_dist_params = {"df": 2.0}
784+
reference_dist = seeded_numpy_distribution_builder("chisquare")
785+
tests_to_run = [
786+
"check_pymc_params_match_rv_op",
787+
"check_pymc_draws_match_reference",
788+
"check_rv_size",
789+
]
790+
791+
785792
class TestBinomial(BaseTestDistribution):
786793
pymc_dist = pm.Binomial
787794
pymc_dist_params = {"n": 100, "p": 0.33}

0 commit comments

Comments
 (0)