Skip to content

Commit a5d1697

Browse files
authored
uncomment test from #4297 (#4302)
* uncomment test * 🎨
1 parent 86dc132 commit a5d1697

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

pymc3/distributions/distribution.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
"vectorized_ppc", default=None
5555
) # type: contextvars.ContextVar[Optional[Callable]]
5656

57+
PLATFORM = sys.platform
58+
5759

5860
class _Unpickling:
5961
pass
@@ -510,17 +512,17 @@ def __init__(
510512
super().__init__(shape, dtype, testval, *args, **kwargs)
511513
self.logp = logp
512514
if type(self.logp) == types.MethodType:
513-
if sys.platform != "linux":
515+
if PLATFORM != "linux":
514516
warnings.warn(
515517
"You are passing a bound method as logp for DensityDist, this can lead to "
516-
+ "errors when sampling on platforms other than Linux. Consider using a "
517-
+ "plain function instead, or subclass Distribution."
518+
"errors when sampling on platforms other than Linux. Consider using a "
519+
"plain function instead, or subclass Distribution."
518520
)
519521
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
520522
warnings.warn(
521523
"You are passing a bound method as logp for DensityDist, this can lead to "
522-
+ "errors when sampling when multiprocessing cannot rely on forking. Consider using a "
523-
+ "plain function instead, or subclass Distribution."
524+
"errors when sampling when multiprocessing cannot rely on forking. Consider using a "
525+
"plain function instead, or subclass Distribution."
524526
)
525527
self.rand = random
526528
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape

pymc3/tests/test_parallel_sampling.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -172,24 +172,23 @@ def func(x):
172172
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
173173

174174

175-
@pytest.mark.xfail(raises=ValueError)
176175
def test_spawn_densitydist_bound_method():
177176
with pm.Model() as model:
178177
mu = pm.Normal("mu", 0, 1)
179178
normal_dist = pm.Normal.dist(mu, 1)
180179
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
181-
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
180+
msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
181+
with pytest.raises(ValueError, match=msg):
182+
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
182183

183184

184-
# cannot test this properly: monkeypatching sys.platform messes up Theano
185-
# def test_spawn_densitydist_syswarning(monkeypatch):
186-
# monkeypatch.setattr(sys, "platform", "win32")
187-
# with pm.Model() as model:
188-
# mu = pm.Normal('mu', 0, 1)
189-
# normal_dist = pm.Normal.dist(mu, 1)
190-
# with pytest.warns(UserWarning) as w:
191-
# obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
192-
# assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]
185+
def test_spawn_densitydist_syswarning(monkeypatch):
186+
monkeypatch.setattr("pymc3.distributions.distribution.PLATFORM", "win32")
187+
with pm.Model() as model:
188+
mu = pm.Normal("mu", 0, 1)
189+
normal_dist = pm.Normal.dist(mu, 1)
190+
with pytest.warns(UserWarning, match="errors when sampling on platforms"):
191+
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
193192

194193

195194
def test_spawn_densitydist_mpctxwarning(monkeypatch):
@@ -198,6 +197,5 @@ def test_spawn_densitydist_mpctxwarning(monkeypatch):
198197
with pm.Model() as model:
199198
mu = pm.Normal("mu", 0, 1)
200199
normal_dist = pm.Normal.dist(mu, 1)
201-
with pytest.warns(UserWarning) as w:
200+
with pytest.warns(UserWarning, match="errors when sampling when multiprocessing"):
202201
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
203-
assert len(w) == 1 and "errors when sampling when multiprocessing" in w[0].message.args[0]

0 commit comments

Comments
 (0)