Skip to content

Commit 6ac841c

Browse files
ricardoV94brandonwillard
authored andcommitted
Update several test xfails
1 parent 48e0d98 commit 6ac841c

File tree

3 files changed

+11
-12
lines changed

3 files changed

+11
-12
lines changed

pymc3/tests/test_examples.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_city_data():
5151
return data.merge(unique, "inner", on="fips")
5252

5353

54-
@pytest.mark.xfail(reason="Bernoulli distribution not refactored")
54+
@pytest.mark.xfail(reason="Bernoulli logitp distribution not refactored")
5555
class TestARM5_4(SeededTest):
5656
def build_model(self):
5757
data = pd.read_csv(
@@ -194,7 +194,7 @@ def build_disaster_model(masked=False):
194194

195195

196196
@pytest.mark.xfail(
197-
reason="DiscreteUniform hasn't been refactored"
197+
reason="_check_start_shape fails with start dictionary"
198198
# condition=(aesara.config.floatX == "float32"), reason="Fails on float32"
199199
)
200200
class TestDisasterModel(SeededTest):
@@ -204,9 +204,9 @@ def test_disaster_model(self):
204204
model = build_disaster_model(masked=False)
205205
with model:
206206
# Initial values for stochastic nodes
207-
start = {"early_mean": 2.0, "late_mean": 3.0}
207+
start = {"early_mean": 2, "late_mean": 3.0}
208208
# Use slice sampler for means (other variables auto-selected)
209-
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
209+
step = pm.Slice([model["early_mean_log__"], model["late_mean_log__"]])
210210
tr = pm.sample(500, tune=50, start=start, step=step, chains=2)
211211
az.summary(tr)
212212

@@ -217,12 +217,12 @@ def test_disaster_model_missing(self):
217217
# Initial values for stochastic nodes
218218
start = {"early_mean": 2.0, "late_mean": 3.0}
219219
# Use slice sampler for means (other variables auto-selected)
220-
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
220+
step = pm.Slice([model["early_mean_log__"], model["late_mean_log__"]])
221221
tr = pm.sample(500, tune=50, start=start, step=step, chains=2)
222222
az.summary(tr)
223223

224224

225-
@pytest.mark.xfail(reason="ZeroInflatedPoisson hasn't been refactored for v4")
225+
@pytest.mark.xfail(reason="_check_start_shape fails with start dictionary")
226226
class TestLatentOccupancy(SeededTest):
227227
"""
228228
From the PyMC example list
@@ -277,14 +277,14 @@ def test_run(self):
277277
"z": (self.y > 0).astype("int16"),
278278
"theta": np.array(5, dtype="f"),
279279
}
280-
step_one = pm.Metropolis([model.theta_interval__, model.psi_logodds__])
280+
step_one = pm.Metropolis([model["theta_interval__"], model["psi_logodds__"]])
281281
step_two = pm.BinaryMetropolis([model.z])
282282
pm.sample(50, step=[step_one, step_two], start=start, chains=1)
283283

284284

285285
@pytest.mark.xfail(
286-
# condition=(aesara.config.floatX == "float32"),
287-
# reason="Fails on float32 due to starting inf at starting logP",
286+
condition=(aesara.config.floatX == "float32"),
287+
reason="Fails on float32 due to starting inf at starting logP",
288288
)
289289
class TestRSV(SeededTest):
290290
"""
@@ -314,7 +314,7 @@ def build_model(self):
314314
# Prior probability
315315
prev_rsv = pm.Beta("prev_rsv", 1, 5, shape=3)
316316
# RSV in Amman
317-
y_amman = pm.Binomial("y_amman", n_amman, prev_rsv, shape=3, testval=100)
317+
y_amman = pm.Binomial("y_amman", n_amman, prev_rsv, shape=3)
318318
# Likelihood for number with RSV in hospital (assumes Pr(hosp | RSV) = 1)
319319
pm.Binomial("y_hosp", y_amman, market_share, observed=rsv_cases)
320320
return model

pymc3/tests/test_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def test_grad(self):
274274
assert val == 21
275275
npt.assert_allclose(grad, [5, 5, 5, 1, 1, 1, 1, 1, 1])
276276

277-
@pytest.mark.xfail(reason="Lognormal not refactored for v4")
277+
@pytest.mark.xfail(reason="Test not refactored for v4")
278278
def test_edge_case(self):
279279
# Edge case discovered in #2948
280280
ndim = 3

pymc3/tests/test_sampling.py

-1
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,6 @@ def test_shape_edgecase(self):
10401040
prior = pm.sample_prior_predictive(10)
10411041
assert prior["mu"].shape == (10, 5)
10421042

1043-
@pytest.mark.xfail(reason="ZeroInflatedPoisson not refactored for v4")
10441043
def test_zeroinflatedpoisson(self):
10451044
with pm.Model():
10461045
theta = pm.Beta("theta", alpha=1, beta=1)

0 commit comments

Comments
 (0)