@@ -89,39 +89,51 @@ def test_remote_pipe_closed():
8989 pm .sample (step = step , mp_ctx = "spawn" , tune = 2 , draws = 2 , cores = 2 , chains = 2 )
9090
9191
92+ @pytest .mark .xfail (
93+ reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
94+ )
9295def test_abort ():
9396 with pm .Model () as model :
9497 a = pm .Normal ("a" , shape = 1 )
9598 pm .HalfNormal ("b" )
9699 step1 = pm .NUTS ([a ])
97- step2 = pm .Metropolis ([model . b_log__ ])
100+ step2 = pm .Metropolis ([model [ " b_log__" ] ])
98101
99102 step = pm .CompoundStep ([step1 , step2 ])
100103
101- ctx = multiprocessing .get_context ()
102- proc = ps .ProcessAdapter (
103- 10 ,
104- 10 ,
105- step ,
106- chain = 3 ,
107- seed = 1 ,
108- mp_ctx = ctx ,
109- start = {"a" : 1.0 , "b_log__" : 2.0 },
110- step_method_pickled = None ,
111- pickle_backend = "pickle" ,
112- )
113- proc .start ()
114- proc .write_next ()
115- proc .abort ()
116- proc .join ()
117-
118-
104+ for abort in [False , True ]:
105+ ctx = multiprocessing .get_context ()
106+ proc = ps .ProcessAdapter (
107+ 10 ,
108+ 10 ,
109+ step ,
110+ chain = 3 ,
111+ seed = 1 ,
112+ mp_ctx = ctx ,
113+ start = {"a" : np .array ([1.0 ]), "b_log__" : np .array (2.0 )},
114+ step_method_pickled = None ,
115+ pickle_backend = "pickle" ,
116+ )
117+ proc .start ()
118+ while True :
119+ proc .write_next ()
120+ out = ps .ProcessAdapter .recv_draw ([proc ])
121+ if out [1 ]:
122+ break
123+ if abort :
124+ proc .abort ()
125+ proc .join ()
126+
127+
128+ @pytest .mark .xfail (
129+ reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
130+ )
119131def test_explicit_sample ():
120132 with pm .Model () as model :
121133 a = pm .Normal ("a" , shape = 1 )
122134 pm .HalfNormal ("b" )
123135 step1 = pm .NUTS ([a ])
124- step2 = pm .Metropolis ([model . b_log__ ])
136+ step2 = pm .Metropolis ([model [ " b_log__" ] ])
125137
126138 step = pm .CompoundStep ([step1 , step2 ])
127139
@@ -133,7 +145,7 @@ def test_explicit_sample():
133145 chain = 3 ,
134146 seed = 1 ,
135147 mp_ctx = ctx ,
136- start = {"a" : 1.0 , "b_log__" : 2.0 },
148+ start = {"a" : np . array ([ 1.0 ]) , "b_log__" : np . array ( 2.0 ) },
137149 step_method_pickled = None ,
138150 pickle_backend = "pickle" ,
139151 )
@@ -149,22 +161,26 @@ def test_explicit_sample():
149161 proc .join ()
150162
151163
164+ @pytest .mark .xfail (
165+ reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
166+ )
152167def test_iterator ():
153168 with pm .Model () as model :
154169 a = pm .Normal ("a" , shape = 1 )
155170 pm .HalfNormal ("b" )
156171 step1 = pm .NUTS ([a ])
157- step2 = pm .Metropolis ([model . b_log__ ])
172+ step2 = pm .Metropolis ([model [ " b_log__" ] ])
158173
159174 step = pm .CompoundStep ([step1 , step2 ])
160175
161- start = {"a" : 1.0 , "b_log__" : 2.0 }
176+ start = {"a" : np . array ([ 1.0 ]) , "b_log__" : np . array ( 2.0 ) }
162177 sampler = ps .ParallelSampler (10 , 10 , 3 , 2 , [2 , 3 , 4 ], [start ] * 3 , step , 0 , False )
163178 with sampler :
164179 for draw in sampler :
165180 pass
166181
167182
183+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
168184def test_spawn_densitydist_function ():
169185 with pm .Model () as model :
170186 mu = pm .Normal ("mu" , 0 , 1 )
@@ -176,16 +192,19 @@ def func(x):
176192 pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
177193
178194
195+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
179196def test_spawn_densitydist_bound_method ():
180197 with pm .Model () as model :
181198 mu = pm .Normal ("mu" , 0 , 1 )
182199 normal_dist = pm .Normal .dist (mu , 1 )
183- obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
200+ logp = lambda x : pm .logp (normal_dist , x , transformed = False )
201+ obs = pm .DensityDist ("density_dist" , logp , observed = np .random .randn (100 ))
184202 msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
185203 with pytest .raises (ValueError , match = msg ):
186204 pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
187205
188206
207+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
189208def test_spawn_densitydist_syswarning (monkeypatch ):
190209 monkeypatch .setattr ("pymc3.distributions.distribution.PLATFORM" , "win32" )
191210 with pm .Model () as model :
@@ -195,6 +214,7 @@ def test_spawn_densitydist_syswarning(monkeypatch):
195214 obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
196215
197216
217+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
198218def test_spawn_densitydist_mpctxwarning (monkeypatch ):
199219 ctx = multiprocessing .get_context ("spawn" )
200220 monkeypatch .setattr (multiprocessing , "get_context" , lambda : ctx )
0 commit comments