Skip to content

Commit ca90775

Browse files
authored
Merge pull request #200 from pymc-labs/its-class
Add `InterruptedTimeSeries` class
2 parents b61295e + b7f023a commit ca90775

File tree

7 files changed

+77
-66
lines changed

7 files changed

+77
-66
lines changed

causalpy/pymc_experiments.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _input_validation(self, data, treatment_time):
134134
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
135135
)
136136

137-
def plot(self):
137+
def plot(self, counterfactual_label="Counterfactual", **kwargs):
138138
"""Plot the results"""
139139
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
140140

@@ -161,7 +161,7 @@ def plot(self):
161161
plot_hdi_kwargs={"color": "C1"},
162162
)
163163
handles.append((h_line, h_patch))
164-
labels.append("Synthetic control")
164+
labels.append(counterfactual_label)
165165

166166
ax[0].plot(self.datapost.index, self.post_y, "k.")
167167
# Shaded causal effect
@@ -243,14 +243,20 @@ def summary(self):
243243
self.print_coefficients()
244244

245245

246+
class InterruptedTimeSeries(PrePostFit):
247+
"""Interrupted time series analysis"""
248+
249+
expt_type = "Interrupted Time Series"
250+
251+
246252
class SyntheticControl(PrePostFit):
247253
"""A wrapper around the PrePostFit class"""
248254

249255
expt_type = "Synthetic Control"
250256

251-
def plot(self, plot_predictors=False):
257+
def plot(self, plot_predictors=False, **kwargs):
252258
"""Plot the results"""
253-
fig, ax = super().plot()
259+
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
254260
if plot_predictors:
255261
# plot control units as well
256262
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)

causalpy/skl_experiments.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
# cumulative impact post
7575
self.post_impact_cumulative = np.cumsum(self.post_impact)
7676

77-
def plot(self):
77+
def plot(self, counterfactual_label="Counterfactual", **kwargs):
7878
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
7979

8080
ax[0].plot(self.datapre.index, self.pre_y, "k.")
@@ -84,7 +84,7 @@ def plot(self):
8484
ax[0].plot(
8585
self.datapost.index,
8686
self.post_pred,
87-
label="counterfactual",
87+
label=counterfactual_label,
8888
ls=":",
8989
c="k",
9090
)
@@ -95,7 +95,7 @@ def plot(self):
9595
self.datapost.index,
9696
self.post_impact,
9797
"k.",
98-
label="counterfactual",
98+
label=counterfactual_label,
9999
)
100100
ax[1].axhline(y=0, c="k")
101101
ax[1].set(title="Causal Impact")
@@ -151,12 +151,18 @@ def plot_coeffs(self):
151151
)
152152

153153

154+
class InterruptedTimeSeries(PrePostFit):
155+
"""Interrupted time series analysis"""
156+
157+
expt_type = "Interrupted Time Series"
158+
159+
154160
class SyntheticControl(PrePostFit):
155161
"""A wrapper around the PrePostFit class"""
156162

157-
def plot(self, plot_predictors=False):
163+
def plot(self, plot_predictors=False, **kwargs):
158164
"""Plot the results"""
159-
fig, ax = super().plot()
165+
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
160166
if plot_predictors:
161167
# plot control units as well
162168
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,14 @@ def test_its_covid():
167167
.set_index("date")
168168
)
169169
treatment_time = pd.to_datetime("2020-01-01")
170-
result = cp.pymc_experiments.SyntheticControl(
170+
result = cp.pymc_experiments.InterruptedTimeSeries(
171171
df,
172172
treatment_time,
173173
formula="standardize(deaths) ~ 0 + standardize(t) + C(month) + standardize(temp)", # noqa E501
174174
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
175175
)
176176
assert isinstance(df, pd.DataFrame)
177-
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
177+
assert isinstance(result, cp.pymc_experiments.InterruptedTimeSeries)
178178
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
179179
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
180180

docs/source/_static/classes.png

10.3 KB
Loading

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/its_covid.ipynb

Lines changed: 43 additions & 49 deletions
Large diffs are not rendered by default.

docs/source/notebooks/sc_skl.ipynb

Lines changed: 8 additions & 3 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)