Skip to content

Commit 6cd35fd

Browse files
committed
Modify HyperGeom.random; Random test still failing. match_with_scipy test passing
1 parent f4edd8d commit 6cd35fd

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

pymc3/distributions/discrete.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -881,8 +881,17 @@ def random(self, point=None, size=None):
881881
-------
882882
array
883883
"""
884+
884885
N, k, n = draw_values([self.N, self.k, self.n], point=point, size=size)
885-
return generate_samples(np.random.hypergeometric, N, k, n, dist_shape=self.shape, size=size)
886+
return generate_samples(self._random, N, k, n, dist_shape=self.shape, size=size)
887+
888+
def _random(self, M, n, N, size=None):
889+
r"""Wrapper around scipy stat's hypergeom.rvs"""
890+
try:
891+
samples = stats.hypergeom.rvs(M=M, n=n, N=N, size=size)
892+
return samples
893+
except ValueError:
894+
raise ValueError("Domain error in arguments")
886895

887896
def logp(self, value):
888897
r"""

pymc3/tests/test_distributions_random.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def pymc3_random_discrete(
109109
p = 1.0
110110
else:
111111
_, p = st.chisquare(k[:, 0], k[:, 1])
112+
print(p)
112113
f -= 1
113114
assert p > alpha, str(pt)
114115

@@ -752,7 +753,7 @@ def ref_rand(size, N, k, n):
752753
pm.HyperGeometric,
753754
{"N": Nat, "k": Nat, "n": Nat},
754755
size=500,
755-
fails=100,
756+
fails=50,
756757
ref_rand=ref_rand,
757758
)
758759

0 commit comments

Comments
 (0)