@@ -172,24 +172,23 @@ def func(x):
172
172
trace = pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
173
173
174
174
175
- @pytest .mark .xfail (raises = ValueError )
176
175
def test_spawn_densitydist_bound_method ():
177
176
with pm .Model () as model :
178
177
mu = pm .Normal ("mu" , 0 , 1 )
179
178
normal_dist = pm .Normal .dist (mu , 1 )
180
179
obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
181
- trace = pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
180
+ msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
181
+ with pytest .raises (ValueError , match = msg ):
182
+ trace = pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
182
183
183
184
184
- # cannot test this properly: monkeypatching sys.platform messes up Theano
185
- # def test_spawn_densitydist_syswarning(monkeypatch):
186
- # monkeypatch.setattr(sys, "platform", "win32")
187
- # with pm.Model() as model:
188
- # mu = pm.Normal('mu', 0, 1)
189
- # normal_dist = pm.Normal.dist(mu, 1)
190
- # with pytest.warns(UserWarning) as w:
191
- # obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
192
- # assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]
185
+ def test_spawn_densitydist_syswarning (monkeypatch ):
186
+ monkeypatch .setattr ("pymc3.distributions.distribution.PLATFORM" , "win32" )
187
+ with pm .Model () as model :
188
+ mu = pm .Normal ("mu" , 0 , 1 )
189
+ normal_dist = pm .Normal .dist (mu , 1 )
190
+ with pytest .warns (UserWarning , match = "errors when sampling on platforms" ):
191
+ obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
193
192
194
193
195
194
def test_spawn_densitydist_mpctxwarning (monkeypatch ):
@@ -198,6 +197,5 @@ def test_spawn_densitydist_mpctxwarning(monkeypatch):
198
197
with pm .Model () as model :
199
198
mu = pm .Normal ("mu" , 0 , 1 )
200
199
normal_dist = pm .Normal .dist (mu , 1 )
201
- with pytest .warns (UserWarning ) as w :
200
+ with pytest .warns (UserWarning , match = "errors when sampling when multiprocessing" ) :
202
201
obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
203
- assert len (w ) == 1 and "errors when sampling when multiprocessing" in w [0 ].message .args [0 ]
0 commit comments