Skip to content

Commit 3fa3d1f

Browse files
authored
fix regression #4273 (#4297)
* informative warnings on bound method logp in DensityDist * black * run test on single core only to avoid dill error on windows/macos * adding tests for DensityDist serialize recursion handling * forgot a test
1 parent df3ae60 commit 3fa3d1f

File tree

3 files changed

+74
-7
lines changed

3 files changed

+74
-7
lines changed

pymc3/distributions/distribution.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import multiprocessing
1516
import numbers
1617
import contextvars
1718
import dill
1819
import inspect
20+
import sys
21+
import types
1922
from typing import TYPE_CHECKING
23+
import warnings
2024

2125
if TYPE_CHECKING:
2226
from typing import Optional, Callable
@@ -505,6 +509,19 @@ def __init__(
505509
dtype = theano.config.floatX
506510
super().__init__(shape, dtype, testval, *args, **kwargs)
507511
self.logp = logp
512+
if type(self.logp) == types.MethodType:
513+
if sys.platform != "linux":
514+
warnings.warn(
515+
"You are passing a bound method as logp for DensityDist, this can lead to "
516+
+ "errors when sampling on platforms other than Linux. Consider using a "
517+
+ "plain function instead, or subclass Distribution."
518+
)
519+
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
520+
warnings.warn(
521+
"You are passing a bound method as logp for DensityDist, this can lead to "
522+
+ "errors when sampling when multiprocessing cannot rely on forking. Consider using a "
523+
+ "plain function instead, or subclass Distribution."
524+
)
508525
self.rand = random
509526
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
510527
self.check_shape_in_random = check_shape_in_random
@@ -513,7 +530,15 @@ def __getstate__(self):
513530
# We use dill to serialize the logp function, as this is almost
514531
# always defined in the notebook and won't be pickled correctly.
515532
# Fix https://github.com/pymc-devs/pymc3/issues/3844
516-
logp = dill.dumps(self.logp)
533+
try:
534+
logp = dill.dumps(self.logp)
535+
except RecursionError as err:
536+
if type(self.logp) == types.MethodType:
537+
raise ValueError(
538+
"logp for DensityDist is a bound method, leading to RecursionError while serializing"
539+
) from err
540+
else:
541+
raise err
517542
vals = self.__dict__.copy()
518543
vals["logp"] = logp
519544
return vals

pymc3/tests/test_distributions_random.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ def test_density_dist_with_random_sampleable(self, shape):
11711171
shape=shape,
11721172
random=normal_dist.random,
11731173
)
1174-
trace = pm.sample(100)
1174+
trace = pm.sample(100, cores=1)
11751175

11761176
samples = 500
11771177
size = 100
@@ -1194,7 +1194,7 @@ def test_density_dist_with_random_sampleable_failure(self, shape):
11941194
random=normal_dist.random,
11951195
wrap_random_with_dist_shape=False,
11961196
)
1197-
trace = pm.sample(100)
1197+
trace = pm.sample(100, cores=1)
11981198

11991199
samples = 500
12001200
with pytest.raises(RuntimeError):
@@ -1217,7 +1217,7 @@ def test_density_dist_with_random_sampleable_hidden_error(self, shape):
12171217
wrap_random_with_dist_shape=False,
12181218
check_shape_in_random=False,
12191219
)
1220-
trace = pm.sample(100)
1220+
trace = pm.sample(100, cores=1)
12211221

12221222
samples = 500
12231223
ppc = pm.sample_posterior_predictive(trace, samples=samples, model=model)
@@ -1240,7 +1240,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success(self):
12401240
random=rvs,
12411241
wrap_random_with_dist_shape=False,
12421242
)
1243-
trace = pm.sample(100)
1243+
trace = pm.sample(100, cores=1)
12441244

12451245
samples = 500
12461246
size = 100
@@ -1260,7 +1260,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success_fast(self):
12601260
random=rvs,
12611261
wrap_random_with_dist_shape=False,
12621262
)
1263-
trace = pm.sample(100)
1263+
trace = pm.sample(100, cores=1)
12641264

12651265
samples = 500
12661266
size = 100
@@ -1273,7 +1273,7 @@ def test_density_dist_without_random_not_sampleable(self):
12731273
mu = pm.Normal("mu", 0, 1)
12741274
normal_dist = pm.Normal.dist(mu, 1)
12751275
pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
1276-
trace = pm.sample(100)
1276+
trace = pm.sample(100, cores=1)
12771277

12781278
samples = 500
12791279
with pytest.raises(ValueError):

pymc3/tests/test_parallel_sampling.py

+42
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,45 @@ def test_iterator():
159159
with sampler:
160160
for draw in sampler:
161161
pass
162+
163+
164+
def test_spawn_densitydist_function():
165+
with pm.Model() as model:
166+
mu = pm.Normal("mu", 0, 1)
167+
168+
def func(x):
169+
return -2 * (x ** 2).sum()
170+
171+
obs = pm.DensityDist("density_dist", func, observed=np.random.randn(100))
172+
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
173+
174+
175+
@pytest.mark.xfail(raises=ValueError)
176+
def test_spawn_densitydist_bound_method():
177+
with pm.Model() as model:
178+
mu = pm.Normal("mu", 0, 1)
179+
normal_dist = pm.Normal.dist(mu, 1)
180+
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
181+
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
182+
183+
184+
# cannot test this properly: monkeypatching sys.platform messes up Theano
185+
# def test_spawn_densitydist_syswarning(monkeypatch):
186+
# monkeypatch.setattr(sys, "platform", "win32")
187+
# with pm.Model() as model:
188+
# mu = pm.Normal('mu', 0, 1)
189+
# normal_dist = pm.Normal.dist(mu, 1)
190+
# with pytest.warns(UserWarning) as w:
191+
# obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
192+
# assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]
193+
194+
195+
def test_spawn_densitydist_mpctxwarning(monkeypatch):
196+
ctx = multiprocessing.get_context("spawn")
197+
monkeypatch.setattr(multiprocessing, "get_context", lambda: ctx)
198+
with pm.Model() as model:
199+
mu = pm.Normal("mu", 0, 1)
200+
normal_dist = pm.Normal.dist(mu, 1)
201+
with pytest.warns(UserWarning) as w:
202+
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
203+
assert len(w) == 1 and "errors when sampling when multiprocessing" in w[0].message.args[0]

0 commit comments

Comments
 (0)