Skip to content

Commit 7e0ca34

Browse files
authored
Merge pull request #438 from lpoug/plot-data
Get plot data for prepostfit experiments
2 parents 2a6f9db + da6c91d commit 7e0ca34

13 files changed

+233
-21
lines changed

causalpy/experiments/base.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from abc import abstractmethod
1919

20+
import pandas as pd
2021
from sklearn.base import RegressorMixin
2122

2223
from causalpy.pymc_models import PyMCModel
@@ -59,22 +60,45 @@ def print_coefficients(self, round_to=None):
5960
def plot(self, *args, **kwargs) -> tuple:
6061
"""Plot the model.
6162
62-
Internally, this function dispatches to either `bayesian_plot` or `ols_plot`
63+
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
6364
depending on the model type.
6465
"""
6566
if isinstance(self.model, PyMCModel):
66-
return self.bayesian_plot(*args, **kwargs)
67+
return self._bayesian_plot(*args, **kwargs)
6768
elif isinstance(self.model, RegressorMixin):
68-
return self.ols_plot(*args, **kwargs)
69+
return self._ols_plot(*args, **kwargs)
6970
else:
7071
raise ValueError("Unsupported model type")
7172

7273
@abstractmethod
73-
def bayesian_plot(self, *args, **kwargs):
74+
def _bayesian_plot(self, *args, **kwargs):
7475
"""Abstract method for plotting the model."""
75-
raise NotImplementedError("bayesian_plot method not yet implemented")
76+
raise NotImplementedError("_bayesian_plot method not yet implemented")
7677

7778
@abstractmethod
78-
def ols_plot(self, *args, **kwargs):
79+
def _ols_plot(self, *args, **kwargs):
7980
"""Abstract method for plotting the model."""
80-
raise NotImplementedError("ols_plot method not yet implemented")
81+
raise NotImplementedError("_ols_plot method not yet implemented")
82+
83+
def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
84+
"""Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
85+
86+
Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
87+
depending on the model type.
88+
"""
89+
if isinstance(self.model, PyMCModel):
90+
return self.get_plot_data_bayesian(*args, **kwargs)
91+
elif isinstance(self.model, RegressorMixin):
92+
return self.get_plot_data_ols(*args, **kwargs)
93+
else:
94+
raise ValueError("Unsupported model type")
95+
96+
@abstractmethod
97+
def get_plot_data_bayesian(self, *args, **kwargs):
98+
"""Abstract method for recovering plot data."""
99+
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
100+
101+
@abstractmethod
102+
def get_plot_data_ols(self, *args, **kwargs):
103+
"""Abstract method for recovering plot data."""
104+
raise NotImplementedError("get_plot_data_ols method not yet implemented")

causalpy/experiments/diff_in_diff.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def _causal_impact_summary_stat(self, round_to=None) -> str:
229229
"""Computes the mean and 94% credible interval bounds for the causal impact."""
230230
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"
231231

232-
def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
232+
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
233233
"""
234234
Plot the results
235235
@@ -367,7 +367,7 @@ def _plot_causal_impact_arrow(results, ax):
367367
)
368368
return fig, ax
369369

370-
def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
370+
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
371371
"""Generate plot for difference-in-differences"""
372372
round_to = kwargs.get("round_to")
373373
fig, ax = plt.subplots()

causalpy/experiments/prepostfit.py

+69-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from sklearn.base import RegressorMixin
2626

2727
from causalpy.custom_exceptions import BadIndexException
28-
from causalpy.plot_utils import plot_xY
28+
from causalpy.plot_utils import get_hdi_to_df, plot_xY
2929
from causalpy.pymc_models import PyMCModel
3030
from causalpy.utils import round_num
3131

@@ -123,7 +123,7 @@ def summary(self, round_to=None) -> None:
123123
print(f"Formula: {self.formula}")
124124
self.print_coefficients(round_to)
125125

126-
def bayesian_plot(
126+
def _bayesian_plot(
127127
self, round_to=None, **kwargs
128128
) -> tuple[plt.Figure, List[plt.Axes]]:
129129
"""
@@ -231,7 +231,7 @@ def bayesian_plot(
231231

232232
return fig, ax
233233

234-
def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
234+
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
235235
"""
236236
Plot the results
237237
@@ -303,6 +303,70 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303303

304304
return (fig, ax)
305305

306+
def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
307+
"""
308+
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309+
310+
:param hdi_prob:
311+
Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
312+
"""
313+
if isinstance(self.model, PyMCModel):
314+
hdi_pct = int(round(hdi_prob * 100))
315+
316+
pred_lower_col = f"pred_hdi_lower_{hdi_pct}"
317+
pred_upper_col = f"pred_hdi_upper_{hdi_pct}"
318+
impact_lower_col = f"impact_hdi_lower_{hdi_pct}"
319+
impact_upper_col = f"impact_hdi_upper_{hdi_pct}"
320+
321+
pre_data = self.datapre.copy()
322+
post_data = self.datapost.copy()
323+
324+
pre_data["prediction"] = (
325+
az.extract(self.pre_pred, group="posterior_predictive", var_names="mu")
326+
.mean("sample")
327+
.values
328+
)
329+
post_data["prediction"] = (
330+
az.extract(self.post_pred, group="posterior_predictive", var_names="mu")
331+
.mean("sample")
332+
.values
333+
)
334+
pre_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
335+
self.pre_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
336+
).set_index(pre_data.index)
337+
post_data[[pred_lower_col, pred_upper_col]] = get_hdi_to_df(
338+
self.post_pred["posterior_predictive"].mu, hdi_prob=hdi_prob
339+
).set_index(post_data.index)
340+
341+
pre_data["impact"] = self.pre_impact.mean(dim=["chain", "draw"]).values
342+
post_data["impact"] = self.post_impact.mean(dim=["chain", "draw"]).values
343+
pre_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
344+
self.pre_impact, hdi_prob=hdi_prob
345+
).set_index(pre_data.index)
346+
post_data[[impact_lower_col, impact_upper_col]] = get_hdi_to_df(
347+
self.post_impact, hdi_prob=hdi_prob
348+
).set_index(post_data.index)
349+
350+
self.plot_data = pd.concat([pre_data, post_data])
351+
352+
return self.plot_data
353+
else:
354+
raise ValueError("Unsupported model type")
355+
356+
def get_plot_data_ols(self) -> pd.DataFrame:
357+
"""
358+
Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
359+
"""
360+
pre_data = self.datapre.copy()
361+
post_data = self.datapost.copy()
362+
pre_data["prediction"] = self.pre_pred
363+
post_data["prediction"] = self.post_pred
364+
pre_data["impact"] = self.pre_impact
365+
post_data["impact"] = self.post_impact
366+
self.plot_data = pd.concat([pre_data, post_data])
367+
368+
return self.plot_data
369+
306370

307371
class InterruptedTimeSeries(PrePostFit):
308372
"""
@@ -382,7 +446,7 @@ class SyntheticControl(PrePostFit):
382446
supports_ols = True
383447
supports_bayes = True
384448

385-
def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
449+
def _bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
386450
"""
387451
Plot the results
388452
@@ -393,7 +457,7 @@ def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
393457
Whether to plot the control units as well. Defaults to False.
394458
"""
395459
# call the super class method
396-
fig, ax = super().bayesian_plot(*args, **kwargs)
460+
fig, ax = super()._bayesian_plot(*args, **kwargs)
397461

398462
# additional plotting functionality for the synthetic control experiment
399463
plot_predictors = kwargs.get("plot_predictors", False)

causalpy/experiments/prepostnegd.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def summary(self, round_to=None) -> None:
200200
print(self._causal_impact_summary_stat(round_to))
201201
self.print_coefficients(round_to)
202202

203-
def bayesian_plot(
203+
def _bayesian_plot(
204204
self, round_to=None, **kwargs
205205
) -> tuple[plt.Figure, List[plt.Axes]]:
206206
"""Generate plot for ANOVA-like experiments with non-equivalent group designs."""

causalpy/experiments/regression_discontinuity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def summary(self, round_to=None) -> None:
218218
print("\n")
219219
self.print_coefficients(round_to)
220220

221-
def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
221+
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
222222
"""Generate plot for regression discontinuity designs."""
223223
fig, ax = plt.subplots()
224224
# Plot raw data
@@ -267,7 +267,7 @@ def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
267267
)
268268
return (fig, ax)
269269

270-
def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
270+
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
271271
"""Generate plot for regression discontinuity designs."""
272272
fig, ax = plt.subplots()
273273
# Plot raw data

causalpy/experiments/regression_kink.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def summary(self, round_to=None) -> None:
189189
)
190190
self.print_coefficients(round_to)
191191

192-
def bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
192+
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
193193
"""Generate plot for regression kink designs."""
194194
fig, ax = plt.subplots()
195195
# Plot raw data

causalpy/plot_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,24 @@ def plot_xY(
7979
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
8080
)[-1]
8181
return (h_line, h_patch)
82+
83+
84+
def get_hdi_to_df(
85+
x: xr.DataArray,
86+
hdi_prob: float = 0.94,
87+
) -> pd.DataFrame:
88+
"""
89+
Utility function to calculate and recover HDI intervals.
90+
91+
:param x:
92+
Xarray data array
93+
:param hdi_prob:
94+
The size of the HDI, default is 0.94
95+
"""
96+
hdi = (
97+
az.hdi(x, hdi_prob=hdi_prob)
98+
.to_dataframe()
99+
.unstack(level="hdi")
100+
.droplevel(0, axis=1)
101+
)
102+
return hdi

0 commit comments

Comments
 (0)