25
25
from sklearn .base import RegressorMixin
26
26
27
27
from causalpy .custom_exceptions import BadIndexException
28
- from causalpy .plot_utils import plot_xY
28
+ from causalpy .plot_utils import get_hdi_to_df , plot_xY
29
29
from causalpy .pymc_models import PyMCModel
30
30
from causalpy .utils import round_num
31
31
@@ -123,7 +123,7 @@ def summary(self, round_to=None) -> None:
123
123
print (f"Formula: { self .formula } " )
124
124
self .print_coefficients (round_to )
125
125
126
- def bayesian_plot (
126
+ def _bayesian_plot (
127
127
self , round_to = None , ** kwargs
128
128
) -> tuple [plt .Figure , List [plt .Axes ]]:
129
129
"""
@@ -231,7 +231,7 @@ def bayesian_plot(
231
231
232
232
return fig , ax
233
233
234
- def ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
234
+ def _ols_plot (self , round_to = None , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
235
235
"""
236
236
Plot the results
237
237
@@ -303,6 +303,70 @@ def ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]
303
303
304
304
return (fig , ax )
305
305
306
+ def get_plot_data_bayesian (self , hdi_prob : float = 0.94 ) -> pd .DataFrame :
307
+ """
308
+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
309
+
310
+ :param hdi_prob:
311
+ Prob for which the highest density interval will be computed. The default value is defined as the default from the :func:`arviz.hdi` function.
312
+ """
313
+ if isinstance (self .model , PyMCModel ):
314
+ hdi_pct = int (round (hdi_prob * 100 ))
315
+
316
+ pred_lower_col = f"pred_hdi_lower_{ hdi_pct } "
317
+ pred_upper_col = f"pred_hdi_upper_{ hdi_pct } "
318
+ impact_lower_col = f"impact_hdi_lower_{ hdi_pct } "
319
+ impact_upper_col = f"impact_hdi_upper_{ hdi_pct } "
320
+
321
+ pre_data = self .datapre .copy ()
322
+ post_data = self .datapost .copy ()
323
+
324
+ pre_data ["prediction" ] = (
325
+ az .extract (self .pre_pred , group = "posterior_predictive" , var_names = "mu" )
326
+ .mean ("sample" )
327
+ .values
328
+ )
329
+ post_data ["prediction" ] = (
330
+ az .extract (self .post_pred , group = "posterior_predictive" , var_names = "mu" )
331
+ .mean ("sample" )
332
+ .values
333
+ )
334
+ pre_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
335
+ self .pre_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
336
+ ).set_index (pre_data .index )
337
+ post_data [[pred_lower_col , pred_upper_col ]] = get_hdi_to_df (
338
+ self .post_pred ["posterior_predictive" ].mu , hdi_prob = hdi_prob
339
+ ).set_index (post_data .index )
340
+
341
+ pre_data ["impact" ] = self .pre_impact .mean (dim = ["chain" , "draw" ]).values
342
+ post_data ["impact" ] = self .post_impact .mean (dim = ["chain" , "draw" ]).values
343
+ pre_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
344
+ self .pre_impact , hdi_prob = hdi_prob
345
+ ).set_index (pre_data .index )
346
+ post_data [[impact_lower_col , impact_upper_col ]] = get_hdi_to_df (
347
+ self .post_impact , hdi_prob = hdi_prob
348
+ ).set_index (post_data .index )
349
+
350
+ self .plot_data = pd .concat ([pre_data , post_data ])
351
+
352
+ return self .plot_data
353
+ else :
354
+ raise ValueError ("Unsupported model type" )
355
+
356
+ def get_plot_data_ols (self ) -> pd .DataFrame :
357
+ """
358
+ Recover the data of a PrePostFit experiment along with the prediction and causal impact information.
359
+ """
360
+ pre_data = self .datapre .copy ()
361
+ post_data = self .datapost .copy ()
362
+ pre_data ["prediction" ] = self .pre_pred
363
+ post_data ["prediction" ] = self .post_pred
364
+ pre_data ["impact" ] = self .pre_impact
365
+ post_data ["impact" ] = self .post_impact
366
+ self .plot_data = pd .concat ([pre_data , post_data ])
367
+
368
+ return self .plot_data
369
+
306
370
307
371
class InterruptedTimeSeries (PrePostFit ):
308
372
"""
@@ -382,7 +446,7 @@ class SyntheticControl(PrePostFit):
382
446
supports_ols = True
383
447
supports_bayes = True
384
448
385
- def bayesian_plot (self , * args , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
449
+ def _bayesian_plot (self , * args , ** kwargs ) -> tuple [plt .Figure , List [plt .Axes ]]:
386
450
"""
387
451
Plot the results
388
452
@@ -393,7 +457,7 @@ def bayesian_plot(self, *args, **kwargs) -> tuple[plt.Figure, List[plt.Axes]]:
393
457
Whether to plot the control units as well. Defaults to False.
394
458
"""
395
459
# call the super class method
396
- fig , ax = super ().bayesian_plot (* args , ** kwargs )
460
+ fig , ax = super ()._bayesian_plot (* args , ** kwargs )
397
461
398
462
# additional plotting functionality for the synthetic control experiment
399
463
plot_predictors = kwargs .get ("plot_predictors" , False )
0 commit comments