Skip to content

Commit 8076612

Browse files
🎉 Start removing Diagnostics & Plots in PyMC3 development
🔥 Remove arviz plots * Remove directly imported arviz plots used in pymc3 plots 🔥 Remove all plots from PyMC3 Plots module 🔥 Remove PyMC3 plots references in Docs 🎨 Mention Plotting & Diagnostics in API page and remove plots reference in __init__.py ⏪ Revert posterior_plot function, test, and docs 🎨 Add deprecation warning to posterior_plot function 🎨 Add context on plot import and import back into __init__.py ✏️ Add warning and details of posterior_plot added Update docs/source/api/plots.rst Co-authored-by: Alexandre ANDORRA <[email protected]> Update docs/source/api/plots.rst Co-authored-by: Alexandre ANDORRA <[email protected]> Update pymc3/plots/__init__.py Co-authored-by: Alexandre ANDORRA <[email protected]> Update pymc3/plots/__init__.py Co-authored-by: Alexandre ANDORRA <[email protected]> ✏️ Update docs to add stats.rst details ✏️ Minor docs notation for posterioplot function(s) 📝 Add breakline before docstring title
1 parent 2824027 commit 8076612

File tree

12 files changed

+52
-221
lines changed

12 files changed

+52
-221
lines changed

benchmarks/benchmarks/benchmarks.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
import time
1615
import timeit
1716

17+
import arviz as az
1818
import numpy as np
1919
import pandas as pd
2020
import theano
@@ -192,7 +192,7 @@ def track_glm_hierarchical_ess(self, init):
192192
compute_convergence_checks=False,
193193
)
194194
tot = time.time() - t0
195-
ess = float(pm.ess(trace, var_names=["mu_a"])["mu_a"].values)
195+
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
196196
return ess / tot
197197

198198
def track_marginal_mixture_model_ess(self, init):
@@ -214,7 +214,7 @@ def track_marginal_mixture_model_ess(self, init):
214214
compute_convergence_checks=False,
215215
)
216216
tot = time.time() - t0
217-
ess = pm.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
217+
ess = az.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
218218
return ess / tot
219219

220220

@@ -245,7 +245,7 @@ def track_glm_hierarchical_ess(self, step):
245245
compute_convergence_checks=False,
246246
)
247247
tot = time.time() - t0
248-
ess = float(pm.ess(trace, var_names=["mu_a"])["mu_a"].values)
248+
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
249249
return ess / tot
250250

251251

@@ -304,7 +304,7 @@ def freefall(y, t, p):
304304
t0 = time.time()
305305
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
306306
tot = time.time() - t0
307-
ess = pm.ess(trace)
307+
ess = az.ess(trace)
308308
return np.mean([ess.sigma, ess.gamma]) / tot
309309

310310

docs/source/api.rst

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ API Reference
2828
api/shape_utils
2929
api/ode
3030

31-
3231
Indices and tables
3332
===================
3433

docs/source/api/plots.rst

+5-12
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,9 @@ Plots are delegated to the
88
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
99
library, a general purpose library for
1010
"exploratory analysis of Bayesian models."
11-
For plots, ``pymc3.<function>`` are now aliases
12-
for ArviZ functions. Thus, the links below will redirect you to
13-
ArviZ docs:
11+
Refer to its documentation to use the plotting functions directly.
1412

15-
- :func:`pymc3.traceplot <arviz:arviz.plot_trace>`
16-
- :func:`pymc3.plot_posterior <arviz:arviz.plot_posterior>`
17-
- :func:`pymc3.forestplot <arviz:arviz.plot_forest>`
18-
- :func:`pymc3.compareplot <arviz:arviz.plot_compare>`
19-
- :func:`pymc3.autocorrplot <arviz:arviz.plot_autocorr>`
20-
- :func:`pymc3.energyplot <arviz:arviz.plot_energy>`
21-
- :func:`pymc3.kdeplot <arviz:arviz.plot_kde>`
22-
- :func:`pymc3.densityplot <arviz:arviz.plot_density>`
23-
- :func:`pymc3.pairplot <arviz:arviz.plot_pair>`
13+
.. currentmodule:: pymc3.plots.posteriorplot
14+
15+
.. automodule:: pymc3.plots.posteriorplot
16+
:members:

docs/source/api/stats.rst

+1-18
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,4 @@ Statistics and diagnostics are delegated to the
55
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
66
library, a general purpose library for
77
"exploratory analysis of Bayesian models."
8-
For statistics and diagnostics, ``pymc3.<function>`` are now aliases
9-
for ArviZ functions. Thus, the links below will redirect you to
10-
ArviZ docs:
11-
12-
.. currentmodule:: pymc3.stats
13-
14-
15-
- :func:`pymc3.bfmi <arviz:arviz.bfmi>`
16-
- :func:`pymc3.compare <arviz:arviz.compare>`
17-
- :func:`pymc3.ess <arviz:arviz.ess>`
18-
- :data:`pymc3.geweke <arviz:arviz.geweke>`
19-
- :func:`pymc3.hpd <arviz:arviz.hpd>`
20-
- :func:`pymc3.loo <arviz:arviz.loo>`
21-
- :func:`pymc3.mcse <arviz:arviz.mcse>`
22-
- :func:`pymc3.r2_score <arviz:arviz.r2_score>`
23-
- :func:`pymc3.rhat <arviz:arviz.rhat>`
24-
- :func:`pymc3.summary <arviz:arviz.summary>`
25-
- :func:`pymc3.waic <arviz:arviz.waic>`
8+
Refer to its documentation to use the diagnostics functions directly.

pymc3/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def __set_compiler_flags():
6161
from pymc3.plots import *
6262
from pymc3.sampling import *
6363
from pymc3.smc import *
64-
from pymc3.stats import *
6564
from pymc3.step_methods import *
6665
from pymc3.tests import test
6766
from pymc3.theanof import *

pymc3/plots/__init__.py

+5-96
Original file line numberDiff line numberDiff line change
@@ -14,108 +14,17 @@
1414

1515
"""PyMC3 Plotting.
1616
17-
Plots are delegated to the ArviZ library, a general purpose library for
18-
"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/
19-
for details on plots.
17+
Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for
18+
exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/.
19+
20+
Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots.
2021
"""
2122
import functools
2223
import sys
2324
import warnings
2425

2526
import arviz as az
2627

27-
28-
def map_args(func):
29-
swaps = [("varnames", "var_names")]
30-
31-
@functools.wraps(func)
32-
def wrapped(*args, **kwargs):
33-
for (old, new) in swaps:
34-
if old in kwargs and new not in kwargs:
35-
warnings.warn(
36-
f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8"
37-
)
38-
kwargs[new] = kwargs.pop(old)
39-
return func(*args, **kwargs)
40-
41-
return wrapped
42-
43-
44-
# pymc3 custom plots: override these names for custom behavior
45-
autocorrplot = map_args(az.plot_autocorr)
46-
forestplot = map_args(az.plot_forest)
47-
kdeplot = map_args(az.plot_kde)
48-
plot_posterior = map_args(az.plot_posterior)
49-
energyplot = map_args(az.plot_energy)
50-
densityplot = map_args(az.plot_density)
51-
pairplot = map_args(az.plot_pair)
52-
53-
# Use compact traceplot by default
54-
@map_args
55-
@functools.wraps(az.plot_trace)
56-
def traceplot(*args, **kwargs):
57-
try:
58-
kwargs.setdefault("compact", True)
59-
return az.plot_trace(*args, **kwargs)
60-
except TypeError:
61-
kwargs.pop("compact")
62-
return az.plot_trace(*args, **kwargs)
63-
64-
65-
# addition arg mapping for compare plot
66-
@functools.wraps(az.plot_compare)
67-
def compareplot(*args, **kwargs):
68-
if "comp_df" in kwargs:
69-
comp_df = kwargs["comp_df"].copy()
70-
else:
71-
args = list(args)
72-
comp_df = args[0].copy()
73-
if "WAIC" in comp_df.columns:
74-
comp_df = comp_df.rename(
75-
index=str,
76-
columns={
77-
"WAIC": "waic",
78-
"pWAIC": "p_waic",
79-
"dWAIC": "d_waic",
80-
"SE": "se",
81-
"dSE": "dse",
82-
"var_warn": "warning",
83-
},
84-
)
85-
elif "LOO" in comp_df.columns:
86-
comp_df = comp_df.rename(
87-
index=str,
88-
columns={
89-
"LOO": "loo",
90-
"pLOO": "p_loo",
91-
"dLOO": "d_loo",
92-
"SE": "se",
93-
"dSE": "dse",
94-
"shape_warn": "warning",
95-
},
96-
)
97-
if "comp_df" in kwargs:
98-
kwargs["comp_df"] = comp_df
99-
else:
100-
args[0] = comp_df
101-
return az.plot_compare(*args, **kwargs)
102-
103-
10428
from pymc3.plots.posteriorplot import plot_posterior_predictive_glm
10529

106-
# Access to arviz plots: base plots provided by arviz
107-
for plot in az.plots.__all__:
108-
setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot)))
109-
110-
__all__ = tuple(az.plots.__all__) + (
111-
"autocorrplot",
112-
"compareplot",
113-
"forestplot",
114-
"kdeplot",
115-
"plot_posterior",
116-
"traceplot",
117-
"energyplot",
118-
"densityplot",
119-
"pairplot",
120-
"plot_posterior_predictive_glm",
121-
)
30+
__all__ = ["plot_posterior_predictive_glm"]

pymc3/plots/posteriorplot.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
import warnings
18+
1719
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
1820

1921
import matplotlib.pyplot as plt
@@ -33,20 +35,33 @@ def plot_posterior_predictive_glm(
3335
**kwargs: Any
3436
) -> None:
3537
"""Plot posterior predictive of a linear model.
36-
:Arguments:
37-
trace: InferenceData or MultiTrace
38-
Output of pm.sample()
39-
eval: <array>
40-
Array over which to evaluate lm
41-
lm: function <default: linear function>
42-
Function mapping parameters at different points
43-
to their respective outputs.
44-
input: point, sample
45-
output: estimated value
46-
samples: int <default=30>
47-
How many posterior samples to draw.
48-
Additional keyword arguments are passed to pylab.plot().
38+
39+
Parameters
40+
----------
41+
trace: InferenceData or MultiTrace
42+
Output of pm.sample()
43+
eval: <array>
44+
Array over which to evaluate lm
45+
lm: function <default: linear function>
46+
Function mapping parameters at different points
47+
to their respective outputs.
48+
input: point, sample
49+
output: estimated value
50+
samples: int <default=30>
51+
How many posterior samples to draw.
52+
kwargs : mapping, optional
53+
Additional keyword arguments are passed to ``matplotlib.pyplot.plot()``.
54+
55+
Warnings
56+
--------
57+
The `plot_posterior_predictive_glm` function will removed in a future PyMC3 release.
4958
"""
59+
warnings.warn(
60+
"The `plot_posterior_predictive_glm` function will migrate to Arviz in a future release. "
61+
"\nKeep up to date with `ArviZ <https://arviz-devs.github.io/arviz/>`_ for future updates.",
62+
DeprecationWarning,
63+
)
64+
5065
if lm is None:
5166
lm = lambda x, sample: sample["Intercept"] + sample["x"] * x
5267

pymc3/sampling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def sample(
416416
...: y = pm.Binomial("y", n=n, p=p, observed=h)
417417
...: trace = pm.sample()
418418
419-
In [3]: pm.summary(trace, kind="stats")
419+
In [3]: az.summary(trace, kind="stats")
420420
421421
Out[3]:
422422
mean sd hdi_3% hdi_97%

pymc3/stats/__init__.py

-69
This file was deleted.

pymc3/step_methods/mlda.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ class MLDA(ArrayStepShared):
333333
... tune=100, step=step_method,
334334
... random_seed=123)
335335
...
336-
... pm.summary(trace, kind="stats")
336+
... az.summary(trace, kind="stats")
337337
mean sd hdi_3% hdi_97%
338338
x 0.99 0.987 -0.734 2.992
339339

pymc3/tests/sampler_fixtures.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import arviz as az
1516
import numpy as np
1617
import numpy.testing as npt
1718
import theano.tensor as tt
@@ -146,12 +147,12 @@ def setup_class(cls):
146147

147148
def test_neff(self):
148149
if hasattr(self, "min_n_eff"):
149-
n_eff = pm.ess(self.trace[self.burn :])
150+
n_eff = az.ess(self.trace[self.burn :])
150151
for var in n_eff:
151152
npt.assert_array_less(self.min_n_eff, n_eff[var])
152153

153154
def test_Rhat(self):
154-
rhat = pm.rhat(self.trace[self.burn :])
155+
rhat = az.rhat(self.trace[self.burn :])
155156
for var in rhat:
156157
npt.assert_allclose(rhat[var], 1, rtol=0.01)
157158

0 commit comments

Comments
 (0)