Skip to content

Commit 7bb5a33

Browse files
Reenable test_parallel_sampling and test_profile
A few XFAILs were added because DensityDist is not yet refactored and a shape issues causes the CompoundStep to fail.
1 parent 554d7f5 commit 7bb5a33

File tree

3 files changed

+56
-25
lines changed

3 files changed

+56
-25
lines changed

.github/workflows/pytest.yml

+8
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ jobs:
3434
--ignore=pymc3/tests/test_model_graph.py
3535
--ignore=pymc3/tests/test_modelcontext.py
3636
--ignore=pymc3/tests/test_parallel_sampling.py
37+
--ignore=pymc3/tests/test_sampling.py
3738
--ignore=pymc3/tests/test_profile.py
3839
--ignore=pymc3/tests/test_step.py
3940
--ignore=pymc3/tests/test_tuning.py
@@ -67,6 +68,10 @@ jobs:
6768
pymc3/tests/test_plots.py
6869
pymc3/tests/test_updates.py
6970
71+
- |
72+
pymc3/tests/test_parallel_sampling.py
73+
pymc3/tests/test_sampling.py
74+
7075
- |
7176
pymc3/tests/test_idata_conversion.py
7277
pymc3/tests/test_distributions_random.py
@@ -78,6 +83,7 @@ jobs:
7883
pymc3/tests/test_model_graph.py
7984
pymc3/tests/test_ode.py
8085
pymc3/tests/test_posdef_sym.py
86+
pymc3/tests/test_profile.py
8187
pymc3/tests/test_quadpotential.py
8288
pymc3/tests/test_shape_handling.py
8389
pymc3/tests/test_step.py
@@ -144,6 +150,7 @@ jobs:
144150
pymc3/tests/test_distributions_random.py
145151
pymc3/tests/test_distributions_timeseries.py
146152
- |
153+
pymc3/tests/test_parallel_sampling.py
147154
pymc3/tests/test_sampling.py
148155
pymc3/tests/test_shared.py
149156
- |
@@ -155,6 +162,7 @@ jobs:
155162
pymc3/tests/test_modelcontext.py
156163
pymc3/tests/test_model_graph.py
157164
pymc3/tests/test_pickling.py
165+
pymc3/tests/test_profile.py
158166
159167
fail-fast: false
160168
runs-on: ${{ matrix.os }}

pymc3/tests/test_parallel_sampling.py

+44-24
Original file line numberDiff line numberDiff line change
@@ -89,39 +89,51 @@ def test_remote_pipe_closed():
8989
pm.sample(step=step, mp_ctx="spawn", tune=2, draws=2, cores=2, chains=2)
9090

9191

92+
@pytest.mark.xfail(
93+
reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
94+
)
9295
def test_abort():
9396
with pm.Model() as model:
9497
a = pm.Normal("a", shape=1)
9598
pm.HalfNormal("b")
9699
step1 = pm.NUTS([a])
97-
step2 = pm.Metropolis([model.b_log__])
100+
step2 = pm.Metropolis([model["b_log__"]])
98101

99102
step = pm.CompoundStep([step1, step2])
100103

101-
ctx = multiprocessing.get_context()
102-
proc = ps.ProcessAdapter(
103-
10,
104-
10,
105-
step,
106-
chain=3,
107-
seed=1,
108-
mp_ctx=ctx,
109-
start={"a": 1.0, "b_log__": 2.0},
110-
step_method_pickled=None,
111-
pickle_backend="pickle",
112-
)
113-
proc.start()
114-
proc.write_next()
115-
proc.abort()
116-
proc.join()
117-
118-
104+
for abort in [False, True]:
105+
ctx = multiprocessing.get_context()
106+
proc = ps.ProcessAdapter(
107+
10,
108+
10,
109+
step,
110+
chain=3,
111+
seed=1,
112+
mp_ctx=ctx,
113+
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
114+
step_method_pickled=None,
115+
pickle_backend="pickle",
116+
)
117+
proc.start()
118+
while True:
119+
proc.write_next()
120+
out = ps.ProcessAdapter.recv_draw([proc])
121+
if out[1]:
122+
break
123+
if abort:
124+
proc.abort()
125+
proc.join()
126+
127+
128+
@pytest.mark.xfail(
129+
reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
130+
)
119131
def test_explicit_sample():
120132
with pm.Model() as model:
121133
a = pm.Normal("a", shape=1)
122134
pm.HalfNormal("b")
123135
step1 = pm.NUTS([a])
124-
step2 = pm.Metropolis([model.b_log__])
136+
step2 = pm.Metropolis([model["b_log__"]])
125137

126138
step = pm.CompoundStep([step1, step2])
127139

@@ -133,7 +145,7 @@ def test_explicit_sample():
133145
chain=3,
134146
seed=1,
135147
mp_ctx=ctx,
136-
start={"a": 1.0, "b_log__": 2.0},
148+
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
137149
step_method_pickled=None,
138150
pickle_backend="pickle",
139151
)
@@ -149,22 +161,26 @@ def test_explicit_sample():
149161
proc.join()
150162

151163

164+
@pytest.mark.xfail(
165+
reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701"
166+
)
152167
def test_iterator():
153168
with pm.Model() as model:
154169
a = pm.Normal("a", shape=1)
155170
pm.HalfNormal("b")
156171
step1 = pm.NUTS([a])
157-
step2 = pm.Metropolis([model.b_log__])
172+
step2 = pm.Metropolis([model["b_log__"]])
158173

159174
step = pm.CompoundStep([step1, step2])
160175

161-
start = {"a": 1.0, "b_log__": 2.0}
176+
start = {"a": np.array([1.0]), "b_log__": np.array(2.0)}
162177
sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3, step, 0, False)
163178
with sampler:
164179
for draw in sampler:
165180
pass
166181

167182

183+
@pytest.mark.xfail(reason="DensityDist was not yet refactored for v4")
168184
def test_spawn_densitydist_function():
169185
with pm.Model() as model:
170186
mu = pm.Normal("mu", 0, 1)
@@ -176,16 +192,19 @@ def func(x):
176192
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
177193

178194

195+
@pytest.mark.xfail(reason="DensityDist was not yet refactored for v4")
179196
def test_spawn_densitydist_bound_method():
180197
with pm.Model() as model:
181198
mu = pm.Normal("mu", 0, 1)
182199
normal_dist = pm.Normal.dist(mu, 1)
183-
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
200+
logp = lambda x: pm.logp(normal_dist, x, transformed=False)
201+
obs = pm.DensityDist("density_dist", logp, observed=np.random.randn(100))
184202
msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing"
185203
with pytest.raises(ValueError, match=msg):
186204
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
187205

188206

207+
@pytest.mark.xfail(reason="DensityDist was not yet refactored for v4")
189208
def test_spawn_densitydist_syswarning(monkeypatch):
190209
monkeypatch.setattr("pymc3.distributions.distribution.PLATFORM", "win32")
191210
with pm.Model() as model:
@@ -195,6 +214,7 @@ def test_spawn_densitydist_syswarning(monkeypatch):
195214
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
196215

197216

217+
@pytest.mark.xfail(reason="DensityDist was not yet refactored for v4")
198218
def test_spawn_densitydist_mpctxwarning(monkeypatch):
199219
ctx = multiprocessing.get_context("spawn")
200220
monkeypatch.setattr(multiprocessing, "get_context", lambda: ctx)

pymc3/tests/test_profile.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pymc3 as pm
16+
1517
from pymc3.tests.models import simple_model
1618

1719

@@ -23,7 +25,8 @@ def test_profile_model(self):
2325
assert self.model.profile(self.model.logpt).fct_call_time > 0
2426

2527
def test_profile_variable(self):
26-
assert self.model.profile(self.model.value_vars[0].logpt).fct_call_time > 0
28+
rv = self.model.basic_RVs[0]
29+
assert self.model.profile(pm.logpt(rv, self.model.rvs_to_values[rv])).fct_call_time
2730

2831
def test_profile_count(self):
2932
count = 1005

0 commit comments

Comments
 (0)