Skip to content

Commit 8e54fc9

Browse files
michaelosthegericardoV94
authored andcommitted
Monkeypatch instructions for migrating away from the old rv.logp API
1 parent 55d455a commit 8e54fc9

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pymc3/distributions/distribution.py

+13
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ def get_moment(op, rv, size, *rv_inputs):
122122
return new_cls
123123

124124

125+
def _make_nice_attr_error(oldcode: str, newcode: str):
126+
def fn(*args, **kwargs):
127+
raise AttributeError(f"The `{oldcode}` method was removed. Instead use `{newcode}`.`")
128+
129+
return fn
130+
131+
125132
class Distribution(metaclass=DistributionMeta):
126133
"""Statistical distribution"""
127134

@@ -243,6 +250,9 @@ def __new__(
243250
functools.partial(str_for_dist, formatting="latex"), rv_out
244251
)
245252

253+
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
254+
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
255+
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
246256
return rv_out
247257

248258
@classmethod
@@ -333,6 +343,9 @@ def dist(
333343
rv_out.update = (rng, new_rng)
334344
rng.default_update = new_rng
335345

346+
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
347+
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
348+
rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
336349
return rv_out
337350

338351

pymc3/tests/test_distributions.py

+25
Original file line numberDiff line numberDiff line change
@@ -3262,3 +3262,28 @@ def test_distinct_rvs():
32623262
pp_samples_2 = pm.sample_prior_predictive(samples=2)
32633263

32643264
assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
3265+
3266+
3267+
@pytest.mark.parametrize(
3268+
"method,newcode",
3269+
[
3270+
("logp", r"pm.logp\(rv, x\)"),
3271+
("logcdf", r"pm.logcdf\(rv, x\)"),
3272+
("random", r"rv.eval\(\)"),
3273+
],
3274+
)
3275+
def test_logp_gives_migration_instructions(method, newcode):
3276+
rv = pm.Normal.dist()
3277+
f = getattr(rv, method)
3278+
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
3279+
f()
3280+
3281+
# A dim-induced resize of the rv created by the `.dist()` API,
3282+
# happening in Distribution.__new__ would make us loose the monkeypatches.
3283+
# So this triggers it to test if the monkeypatch still works.
3284+
with pm.Model(coords={"year": [2019, 2021, 2022]}):
3285+
rv = pm.Normal("n", dims="year")
3286+
f = getattr(rv, method)
3287+
with pytest.raises(AttributeError, match=rf"use `{newcode}`"):
3288+
f()
3289+
pass

0 commit comments

Comments
 (0)