|
14 | 14 |
|
15 | 15 | """PyMC3 Plotting. |
16 | 16 |
|
17 | | -Plots are delegated to the ArviZ library, a general purpose library for |
18 | | -"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/ |
19 | | -for details on plots. |
| 17 | +Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for |
| 18 | +exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/. |
| 19 | +
|
| 20 | +Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots. |
20 | 21 | """ |
21 | 22 | import functools |
22 | 23 | import sys |
23 | 24 | import warnings |
24 | 25 |
|
25 | 26 | import arviz as az |
26 | 27 |
|
27 | | - |
28 | | -def map_args(func): |
29 | | - swaps = [("varnames", "var_names")] |
30 | | - |
31 | | - @functools.wraps(func) |
32 | | - def wrapped(*args, **kwargs): |
33 | | - for (old, new) in swaps: |
34 | | - if old in kwargs and new not in kwargs: |
35 | | - warnings.warn( |
36 | | - f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8" |
37 | | - ) |
38 | | - kwargs[new] = kwargs.pop(old) |
39 | | - return func(*args, **kwargs) |
40 | | - |
41 | | - return wrapped |
42 | | - |
43 | | - |
44 | | -# pymc3 custom plots: override these names for custom behavior |
45 | | -autocorrplot = map_args(az.plot_autocorr) |
46 | | -forestplot = map_args(az.plot_forest) |
47 | | -kdeplot = map_args(az.plot_kde) |
48 | | -plot_posterior = map_args(az.plot_posterior) |
49 | | -energyplot = map_args(az.plot_energy) |
50 | | -densityplot = map_args(az.plot_density) |
51 | | -pairplot = map_args(az.plot_pair) |
52 | | - |
53 | | -# Use compact traceplot by default |
54 | | -@map_args |
55 | | -@functools.wraps(az.plot_trace) |
56 | | -def traceplot(*args, **kwargs): |
57 | | - try: |
58 | | - kwargs.setdefault("compact", True) |
59 | | - return az.plot_trace(*args, **kwargs) |
60 | | - except TypeError: |
61 | | - kwargs.pop("compact") |
62 | | - return az.plot_trace(*args, **kwargs) |
63 | | - |
64 | | - |
65 | | -# addition arg mapping for compare plot |
66 | | -@functools.wraps(az.plot_compare) |
67 | | -def compareplot(*args, **kwargs): |
68 | | - if "comp_df" in kwargs: |
69 | | - comp_df = kwargs["comp_df"].copy() |
70 | | - else: |
71 | | - args = list(args) |
72 | | - comp_df = args[0].copy() |
73 | | - if "WAIC" in comp_df.columns: |
74 | | - comp_df = comp_df.rename( |
75 | | - index=str, |
76 | | - columns={ |
77 | | - "WAIC": "waic", |
78 | | - "pWAIC": "p_waic", |
79 | | - "dWAIC": "d_waic", |
80 | | - "SE": "se", |
81 | | - "dSE": "dse", |
82 | | - "var_warn": "warning", |
83 | | - }, |
84 | | - ) |
85 | | - elif "LOO" in comp_df.columns: |
86 | | - comp_df = comp_df.rename( |
87 | | - index=str, |
88 | | - columns={ |
89 | | - "LOO": "loo", |
90 | | - "pLOO": "p_loo", |
91 | | - "dLOO": "d_loo", |
92 | | - "SE": "se", |
93 | | - "dSE": "dse", |
94 | | - "shape_warn": "warning", |
95 | | - }, |
96 | | - ) |
97 | | - if "comp_df" in kwargs: |
98 | | - kwargs["comp_df"] = comp_df |
99 | | - else: |
100 | | - args[0] = comp_df |
101 | | - return az.plot_compare(*args, **kwargs) |
102 | | - |
103 | | - |
104 | 28 | from pymc3.plots.posteriorplot import plot_posterior_predictive_glm |
105 | 29 |
|
106 | | -# Access to arviz plots: base plots provided by arviz |
107 | | -for plot in az.plots.__all__: |
108 | | - setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot))) |
109 | | - |
110 | | -__all__ = tuple(az.plots.__all__) + ( |
111 | | - "autocorrplot", |
112 | | - "compareplot", |
113 | | - "forestplot", |
114 | | - "kdeplot", |
115 | | - "plot_posterior", |
116 | | - "traceplot", |
117 | | - "energyplot", |
118 | | - "densityplot", |
119 | | - "pairplot", |
120 | | - "plot_posterior_predictive_glm", |
121 | | -) |
| 30 | +__all__ = ["plot_posterior_predictive_glm"] |
0 commit comments