|  | 
| 14 | 14 | 
 | 
| 15 | 15 | """PyMC3 Plotting. | 
| 16 | 16 | 
 | 
| 17 |  | -Plots are delegated to the ArviZ library, a general purpose library for | 
| 18 |  | -"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/ | 
| 19 |  | -for details on plots. | 
|  | 17 | +Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for | 
|  | 18 | +exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/. | 
|  | 19 | +
 | 
|  | 20 | +Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots. | 
| 20 | 21 | """ | 
| 21 | 22 | import functools | 
| 22 | 23 | import sys | 
| 23 | 24 | import warnings | 
| 24 | 25 | 
 | 
| 25 | 26 | import arviz as az | 
| 26 | 27 | 
 | 
| 27 |  | - | 
| 28 |  | -def map_args(func): | 
| 29 |  | -    swaps = [("varnames", "var_names")] | 
| 30 |  | - | 
| 31 |  | -    @functools.wraps(func) | 
| 32 |  | -    def wrapped(*args, **kwargs): | 
| 33 |  | -        for (old, new) in swaps: | 
| 34 |  | -            if old in kwargs and new not in kwargs: | 
| 35 |  | -                warnings.warn( | 
| 36 |  | -                    f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8" | 
| 37 |  | -                ) | 
| 38 |  | -                kwargs[new] = kwargs.pop(old) | 
| 39 |  | -            return func(*args, **kwargs) | 
| 40 |  | - | 
| 41 |  | -    return wrapped | 
| 42 |  | - | 
| 43 |  | - | 
| 44 |  | -# pymc3 custom plots: override these names for custom behavior | 
| 45 |  | -autocorrplot = map_args(az.plot_autocorr) | 
| 46 |  | -forestplot = map_args(az.plot_forest) | 
| 47 |  | -kdeplot = map_args(az.plot_kde) | 
| 48 |  | -plot_posterior = map_args(az.plot_posterior) | 
| 49 |  | -energyplot = map_args(az.plot_energy) | 
| 50 |  | -densityplot = map_args(az.plot_density) | 
| 51 |  | -pairplot = map_args(az.plot_pair) | 
| 52 |  | - | 
| 53 |  | -# Use compact traceplot by default | 
| 54 |  | -@map_args | 
| 55 |  | -@functools.wraps(az.plot_trace) | 
| 56 |  | -def traceplot(*args, **kwargs): | 
| 57 |  | -    try: | 
| 58 |  | -        kwargs.setdefault("compact", True) | 
| 59 |  | -        return az.plot_trace(*args, **kwargs) | 
| 60 |  | -    except TypeError: | 
| 61 |  | -        kwargs.pop("compact") | 
| 62 |  | -        return az.plot_trace(*args, **kwargs) | 
| 63 |  | - | 
| 64 |  | - | 
| 65 |  | -# addition arg mapping for compare plot | 
| 66 |  | -@functools.wraps(az.plot_compare) | 
| 67 |  | -def compareplot(*args, **kwargs): | 
| 68 |  | -    if "comp_df" in kwargs: | 
| 69 |  | -        comp_df = kwargs["comp_df"].copy() | 
| 70 |  | -    else: | 
| 71 |  | -        args = list(args) | 
| 72 |  | -        comp_df = args[0].copy() | 
| 73 |  | -    if "WAIC" in comp_df.columns: | 
| 74 |  | -        comp_df = comp_df.rename( | 
| 75 |  | -            index=str, | 
| 76 |  | -            columns={ | 
| 77 |  | -                "WAIC": "waic", | 
| 78 |  | -                "pWAIC": "p_waic", | 
| 79 |  | -                "dWAIC": "d_waic", | 
| 80 |  | -                "SE": "se", | 
| 81 |  | -                "dSE": "dse", | 
| 82 |  | -                "var_warn": "warning", | 
| 83 |  | -            }, | 
| 84 |  | -        ) | 
| 85 |  | -    elif "LOO" in comp_df.columns: | 
| 86 |  | -        comp_df = comp_df.rename( | 
| 87 |  | -            index=str, | 
| 88 |  | -            columns={ | 
| 89 |  | -                "LOO": "loo", | 
| 90 |  | -                "pLOO": "p_loo", | 
| 91 |  | -                "dLOO": "d_loo", | 
| 92 |  | -                "SE": "se", | 
| 93 |  | -                "dSE": "dse", | 
| 94 |  | -                "shape_warn": "warning", | 
| 95 |  | -            }, | 
| 96 |  | -        ) | 
| 97 |  | -    if "comp_df" in kwargs: | 
| 98 |  | -        kwargs["comp_df"] = comp_df | 
| 99 |  | -    else: | 
| 100 |  | -        args[0] = comp_df | 
| 101 |  | -    return az.plot_compare(*args, **kwargs) | 
| 102 |  | - | 
| 103 |  | - | 
| 104 | 28 | from pymc3.plots.posteriorplot import plot_posterior_predictive_glm | 
| 105 | 29 | 
 | 
| 106 |  | -# Access to arviz plots: base plots provided by arviz | 
| 107 |  | -for plot in az.plots.__all__: | 
| 108 |  | -    setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot))) | 
| 109 |  | - | 
| 110 |  | -__all__ = tuple(az.plots.__all__) + ( | 
| 111 |  | -    "autocorrplot", | 
| 112 |  | -    "compareplot", | 
| 113 |  | -    "forestplot", | 
| 114 |  | -    "kdeplot", | 
| 115 |  | -    "plot_posterior", | 
| 116 |  | -    "traceplot", | 
| 117 |  | -    "energyplot", | 
| 118 |  | -    "densityplot", | 
| 119 |  | -    "pairplot", | 
| 120 |  | -    "plot_posterior_predictive_glm", | 
| 121 |  | -) | 
|  | 30 | +__all__ = ["plot_posterior_predictive_glm"] | 
0 commit comments