Skip to content

Commit 198d13e

Browse files
Add HyperGeometric Distribution to pymc3.distributions.discrete #4108 (#4249)
* Add HyperGeometric distribution to discrete.py; Add tests * Add HyperGeo to distirbutions/__init__.py * Fix minor linting issue * Add ref_rand helper function. Clip lower in logp * Fix bug. Now pymc3_matches_scipy runs without error but pymc3_random_discrete diverges from expected value * passes match with scipy test in test_distributions.py but fails in test_distributions_random.py * Modify HyperGeom.random; Random test still failing. match_with_scipy test passing * rm stray print * Fix failing random test by specifying domain * Update pymc3/distributions/discrete.py Remove stray newline Co-authored-by: Tirth Patel <[email protected]> * Add note in RELEASE-NOTES.md Co-authored-by: Tirth Patel <[email protected]>
1 parent ced05db commit 198d13e

File tree

5 files changed

+146
-0
lines changed

5 files changed

+146
-0
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
3535
- Change SMC metropolis kernel to independent metropolis kernel [#4115](https://github.com/pymc-devs/pymc3/pull/4115))
3636
- Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
3737
- Added semantically meaningful `str` representations to PyMC3 objects for console, notebook, and GraphViz use (see [#4076](https://github.com/pymc-devs/pymc3/pull/4076), [#4065](https://github.com/pymc-devs/pymc3/pull/4065), [#4159](https://github.com/pymc-devs/pymc3/pull/4159), [#4217](https://github.com/pymc-devs/pymc3/pull/4217), and [#4243](https://github.com/pymc-devs/pymc3/pull/4243)).
38+
- Add Discrete HyperGeometric Distribution (see [#4249](https://github.com/pymc-devs/pymc3/pull/#4249))
3839

3940
### Maintenance
4041
- Switch the dependency of Theano to our own fork, [Theano-PyMC](https://github.com/pymc-devs/Theano-PyMC).

pymc3/distributions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from .discrete import ZeroInflatedBinomial
6464
from .discrete import DiscreteUniform
6565
from .discrete import Geometric
66+
from .discrete import HyperGeometric
6667
from .discrete import Categorical
6768
from .discrete import OrderedLogistic
6869

@@ -141,6 +142,7 @@
141142
"ZeroInflatedBinomial",
142143
"DiscreteUniform",
143144
"Geometric",
145+
"HyperGeometric",
144146
"Categorical",
145147
"OrderedLogistic",
146148
"DensityDist",

pymc3/distributions/discrete.py

+113
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"ZeroInflatedNegativeBinomial",
3939
"DiscreteUniform",
4040
"Geometric",
41+
"HyperGeometric",
4142
"Categorical",
4243
"OrderedLogistic",
4344
]
@@ -809,6 +810,118 @@ def logp(self, value):
809810
return bound(tt.log(p) + logpow(1 - p, value - 1), 0 <= p, p <= 1, value >= 1)
810811

811812

813+
class HyperGeometric(Discrete):
814+
R"""
815+
Discrete hypergeometric distribution.
816+
817+
The probability of :math:`x` successes in a sequence of :math:`n` bernoulli
818+
trials taken without replacement from a population of :math:`N` objects,
819+
containing :math:`k` good (or successful or Type I) objects.
820+
The pmf of this distribution is
821+
822+
.. math:: f(x \mid N, n, k) = \frac{\binom{k}{x}\binom{N-k}{n-x}}{\binom{N}{n}}
823+
824+
.. plot::
825+
826+
import matplotlib.pyplot as plt
827+
import numpy as np
828+
import scipy.stats as st
829+
plt.style.use('seaborn-darkgrid')
830+
x = np.arange(1, 15)
831+
N = 50
832+
k = 10
833+
for n in [20, 25]:
834+
pmf = st.hypergeom.pmf(x, N, k, n)
835+
plt.plot(x, pmf, '-o', label='n = {}'.format(n))
836+
plt.plot(x, pmf, '-o', label='N = {}'.format(N))
837+
plt.plot(x, pmf, '-o', label='k = {}'.format(k))
838+
plt.xlabel('x', fontsize=12)
839+
plt.ylabel('f(x)', fontsize=12)
840+
plt.legend(loc=1)
841+
plt.show()
842+
843+
======== =============================
844+
Support :math:`x \in \left[\max(0, n - N + k), \min(k, n)\right]`
845+
Mean :math:`\dfrac{nk}{N}`
846+
Variance :math:`\dfrac{(N-n)nk(N-k)}{(N-1)N^2}`
847+
======== =============================
848+
849+
Parameters
850+
----------
851+
N : integer
852+
Total size of the population
853+
k : integer
854+
Number of successful individuals in the population
855+
n : integer
856+
Number of samples drawn from the population
857+
"""
858+
859+
def __init__(self, N, k, n, *args, **kwargs):
860+
super().__init__(*args, **kwargs)
861+
self.N = intX(N)
862+
self.k = intX(k)
863+
self.n = intX(n)
864+
self.mode = intX(tt.floor((n + 1) * (k + 1) / (N + 2)))
865+
866+
def random(self, point=None, size=None):
867+
r"""
868+
Draw random values from HyperGeometric distribution.
869+
870+
Parameters
871+
----------
872+
point : dict, optional
873+
Dict of variable values on which random values are to be
874+
conditioned (uses default point if not specified).
875+
size : int, optional
876+
Desired size of random sample (returns one sample if not
877+
specified).
878+
879+
Returns
880+
-------
881+
array
882+
"""
883+
884+
N, k, n = draw_values([self.N, self.k, self.n], point=point, size=size)
885+
return generate_samples(self._random, N, k, n, dist_shape=self.shape, size=size)
886+
887+
def _random(self, M, n, N, size=None):
888+
r"""Wrapper around scipy stat's hypergeom.rvs"""
889+
try:
890+
samples = stats.hypergeom.rvs(M=M, n=n, N=N, size=size)
891+
return samples
892+
except ValueError:
893+
raise ValueError("Domain error in arguments")
894+
895+
def logp(self, value):
896+
r"""
897+
Calculate log-probability of HyperGeometric distribution at specified value.
898+
899+
Parameters
900+
----------
901+
value : numeric
902+
Value(s) for which log-probability is calculated. If the log probabilities for multiple
903+
values are desired the values must be provided in a numpy array or theano tensor
904+
905+
Returns
906+
-------
907+
TensorVariable
908+
"""
909+
N = self.N
910+
k = self.k
911+
n = self.n
912+
tot, good = N, k
913+
bad = tot - good
914+
result = (
915+
betaln(good + 1, 1)
916+
+ betaln(bad + 1, 1)
917+
+ betaln(tot - n + 1, n + 1)
918+
- betaln(value + 1, good - value + 1)
919+
- betaln(n - value + 1, bad - n + value + 1)
920+
- betaln(tot + 1, 1)
921+
)
922+
return result
923+
924+
812925
class DiscreteUniform(Discrete):
813926
R"""
814927
Discrete uniform distribution.

pymc3/tests/test_distributions.py

+9
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
Rice,
7676
Kumaraswamy,
7777
Moyal,
78+
HyperGeometric,
7879
)
7980

8081
from ..distributions import continuous
@@ -790,6 +791,14 @@ def test_geometric(self):
790791
Geometric, Nat, {"p": Unit}, lambda value, p: np.log(sp.geom.pmf(value, p))
791792
)
792793

794+
def test_hypergeometric(self):
795+
self.pymc3_matches_scipy(
796+
HyperGeometric,
797+
Nat,
798+
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
799+
lambda value, N, k, n: sp.hypergeom.logpmf(value, N, k, n),
800+
)
801+
793802
def test_negative_binomial(self):
794803
def test_fun(value, mu, alpha):
795804
return sp.nbinom.logpmf(value, alpha, 1 - mu / (mu + alpha))

pymc3/tests/test_distributions_random.py

+21
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,11 @@ class TestGeometric(BaseTestCases.BaseTestCase):
507507
params = {"p": 0.5}
508508

509509

510+
class TestHyperGeometric(BaseTestCases.BaseTestCase):
511+
distribution = pm.HyperGeometric
512+
params = {"N": 50, "k": 25, "n": 10}
513+
514+
510515
class TestMoyal(BaseTestCases.BaseTestCase):
511516
distribution = pm.Moyal
512517
params = {"mu": 0.0, "sigma": 1.0}
@@ -739,6 +744,22 @@ def ref_rand(size, alpha, mu):
739744
def test_geometric(self):
740745
pymc3_random_discrete(pm.Geometric, {"p": Unit}, size=500, fails=50, ref_rand=nr.geometric)
741746

747+
def test_hypergeometric(self):
748+
def ref_rand(size, N, k, n):
749+
return st.hypergeom.rvs(M=N, n=k, N=n, size=size)
750+
751+
pymc3_random_discrete(
752+
pm.HyperGeometric,
753+
{
754+
"N": Domain([10, 11, 12, 13], "int64"),
755+
"k": Domain([4, 5, 6, 7], "int64"),
756+
"n": Domain([6, 7, 8, 9], "int64"),
757+
},
758+
size=500,
759+
fails=50,
760+
ref_rand=ref_rand,
761+
)
762+
742763
def test_discrete_uniform(self):
743764
def ref_rand(size, lower, upper):
744765
return st.randint.rvs(lower, upper + 1, size=size)

0 commit comments

Comments
 (0)