@@ -89,39 +89,51 @@ def test_remote_pipe_closed():
89
89
pm .sample (step = step , mp_ctx = "spawn" , tune = 2 , draws = 2 , cores = 2 , chains = 2 )
90
90
91
91
92
+ @pytest .mark .xfail (
93
+ reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
94
+ )
92
95
def test_abort ():
93
96
with pm .Model () as model :
94
97
a = pm .Normal ("a" , shape = 1 )
95
98
pm .HalfNormal ("b" )
96
99
step1 = pm .NUTS ([a ])
97
- step2 = pm .Metropolis ([model . b_log__ ])
100
+ step2 = pm .Metropolis ([model [ " b_log__" ] ])
98
101
99
102
step = pm .CompoundStep ([step1 , step2 ])
100
103
101
- ctx = multiprocessing .get_context ()
102
- proc = ps .ProcessAdapter (
103
- 10 ,
104
- 10 ,
105
- step ,
106
- chain = 3 ,
107
- seed = 1 ,
108
- mp_ctx = ctx ,
109
- start = {"a" : 1.0 , "b_log__" : 2.0 },
110
- step_method_pickled = None ,
111
- pickle_backend = "pickle" ,
112
- )
113
- proc .start ()
114
- proc .write_next ()
115
- proc .abort ()
116
- proc .join ()
117
-
118
-
104
+ for abort in [False , True ]:
105
+ ctx = multiprocessing .get_context ()
106
+ proc = ps .ProcessAdapter (
107
+ 10 ,
108
+ 10 ,
109
+ step ,
110
+ chain = 3 ,
111
+ seed = 1 ,
112
+ mp_ctx = ctx ,
113
+ start = {"a" : np .array ([1.0 ]), "b_log__" : np .array (2.0 )},
114
+ step_method_pickled = None ,
115
+ pickle_backend = "pickle" ,
116
+ )
117
+ proc .start ()
118
+ while True :
119
+ proc .write_next ()
120
+ out = ps .ProcessAdapter .recv_draw ([proc ])
121
+ if out [1 ]:
122
+ break
123
+ if abort :
124
+ proc .abort ()
125
+ proc .join ()
126
+
127
+
128
+ @pytest .mark .xfail (
129
+ reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
130
+ )
119
131
def test_explicit_sample ():
120
132
with pm .Model () as model :
121
133
a = pm .Normal ("a" , shape = 1 )
122
134
pm .HalfNormal ("b" )
123
135
step1 = pm .NUTS ([a ])
124
- step2 = pm .Metropolis ([model . b_log__ ])
136
+ step2 = pm .Metropolis ([model [ " b_log__" ] ])
125
137
126
138
step = pm .CompoundStep ([step1 , step2 ])
127
139
@@ -133,7 +145,7 @@ def test_explicit_sample():
133
145
chain = 3 ,
134
146
seed = 1 ,
135
147
mp_ctx = ctx ,
136
- start = {"a" : 1.0 , "b_log__" : 2.0 },
148
+ start = {"a" : np . array ([ 1.0 ]) , "b_log__" : np . array ( 2.0 ) },
137
149
step_method_pickled = None ,
138
150
pickle_backend = "pickle" ,
139
151
)
@@ -149,22 +161,26 @@ def test_explicit_sample():
149
161
proc .join ()
150
162
151
163
164
+ @pytest .mark .xfail (
165
+ reason = "Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
166
+ )
152
167
def test_iterator ():
153
168
with pm .Model () as model :
154
169
a = pm .Normal ("a" , shape = 1 )
155
170
pm .HalfNormal ("b" )
156
171
step1 = pm .NUTS ([a ])
157
- step2 = pm .Metropolis ([model . b_log__ ])
172
+ step2 = pm .Metropolis ([model [ " b_log__" ] ])
158
173
159
174
step = pm .CompoundStep ([step1 , step2 ])
160
175
161
- start = {"a" : 1.0 , "b_log__" : 2.0 }
176
+ start = {"a" : np . array ([ 1.0 ]) , "b_log__" : np . array ( 2.0 ) }
162
177
sampler = ps .ParallelSampler (10 , 10 , 3 , 2 , [2 , 3 , 4 ], [start ] * 3 , step , 0 , False )
163
178
with sampler :
164
179
for draw in sampler :
165
180
pass
166
181
167
182
183
+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
168
184
def test_spawn_densitydist_function ():
169
185
with pm .Model () as model :
170
186
mu = pm .Normal ("mu" , 0 , 1 )
@@ -176,16 +192,19 @@ def func(x):
176
192
pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
177
193
178
194
195
+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
179
196
def test_spawn_densitydist_bound_method ():
180
197
with pm .Model () as model :
181
198
mu = pm .Normal ("mu" , 0 , 1 )
182
199
normal_dist = pm .Normal .dist (mu , 1 )
183
- obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
200
+ logp = lambda x : pm .logp (normal_dist , x , transformed = False )
201
+ obs = pm .DensityDist ("density_dist" , logp , observed = np .random .randn (100 ))
184
202
msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
185
203
with pytest .raises (ValueError , match = msg ):
186
204
pm .sample (draws = 10 , tune = 10 , step = pm .Metropolis (), cores = 2 , mp_ctx = "spawn" )
187
205
188
206
207
+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
189
208
def test_spawn_densitydist_syswarning (monkeypatch ):
190
209
monkeypatch .setattr ("pymc3.distributions.distribution.PLATFORM" , "win32" )
191
210
with pm .Model () as model :
@@ -195,6 +214,7 @@ def test_spawn_densitydist_syswarning(monkeypatch):
195
214
obs = pm .DensityDist ("density_dist" , normal_dist .logp , observed = np .random .randn (100 ))
196
215
197
216
217
+ @pytest .mark .xfail (reason = "DensityDist was not yet refactored for v4" )
198
218
def test_spawn_densitydist_mpctxwarning (monkeypatch ):
199
219
ctx = multiprocessing .get_context ("spawn" )
200
220
monkeypatch .setattr (multiprocessing , "get_context" , lambda : ctx )
0 commit comments