Skip to content

Commit 2b2cbdf

Browse files
committed
Removing the overriding of fit and calculate_impact, adding a test and fixing a bug
1 parent 4aef14b commit 2b2cbdf

File tree

6 files changed

+128
-184
lines changed

6 files changed

+128
-184
lines changed

causalpy/pymc_models.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -699,25 +699,6 @@ def build_model(self, X, y, coords):
699699
# Likelihodd of the base time series and the intervention's effect
700700
pm.Normal("y_ts", mu=mu_ts, sigma=sigma, observed=y, dims="obs_ind")
701701

702-
def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
703-
"""Draw samples from posterior, prior predictive, and posterior predictive
704-
distributions, placing them in the model's idata attribute.
705-
"""
706-
707-
# Ensure random_seed is used in sample_prior_predictive() and
708-
# sample_posterior_predictive() if provided in sample_kwargs.
709-
random_seed = self.sample_kwargs.get("random_seed", None)
710-
self.build_model(X, y, coords)
711-
with self:
712-
self.idata = pm.sample(**self.sample_kwargs)
713-
self.idata.extend(pm.sample_prior_predictive(random_seed=random_seed))
714-
self.idata.extend(
715-
pm.sample_posterior_predictive(
716-
self.idata, progressbar=False, random_seed=random_seed
717-
)
718-
)
719-
return self.idata
720-
721702
def predict(self, X):
722703
"""
723704
Predict data given input data `X`
@@ -731,19 +712,20 @@ def predict(self, X):
731712
random_seed = self.sample_kwargs.get("random_seed", None)
732713
self._data_setter(X)
733714
with self: # sample with new input data
734-
post_pred = pm.sample_posterior_predictive(
715+
pp = pm.sample_posterior_predictive(
735716
self.idata,
736717
var_names=["y_hat", "y_ts", "mu", "mu_ts", "mu_in"],
737718
progressbar=False,
738719
random_seed=random_seed,
739720
)
740-
return post_pred
741721

742-
def calculate_impact(
743-
self, y_true: xr.DataArray, y_pred: az.InferenceData
744-
) -> xr.DataArray:
745-
impact = y_true.data - y_pred["posterior_predictive"]["y_hat"]
746-
return impact.transpose(..., "obs_ind")
722+
# TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter?
723+
if isinstance(X, xr.DataArray):
724+
pp["posterior_predictive"] = pp["posterior_predictive"].assign_coords(
725+
obs_ind=X.obs_ind
726+
)
727+
728+
return pp
747729

748730
def _data_setter(self, X) -> None:
749731
"""
@@ -770,7 +752,7 @@ def score(self, X, y) -> pd.Series:
770752
mu_ts = self.predict(X)
771753
mu_ts = az.extract(mu_ts, group="posterior_predictive", var_names="mu_ts").T
772754
# Note: First argument must be a 1D array
773-
return r2_score(y.data, mu_ts)
755+
return r2_score(y.data, mu_ts.data)
774756

775757
def set_time_range(self, time_range, data):
776758
"""

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,63 @@ def test_its():
402402
)
403403

404404

405+
@pytest.mark.integration
406+
def test_its_no_treatment_time():
407+
"""
408+
Test Interrupted Time-Series experiment on COVID data with an unknown treatment time.
409+
410+
Loads data and checks:
411+
1. data is a dataframe
412+
2. causalpy.InterruptedtimeSeries returns correct type
413+
3. the correct number of MCMC chains exists in the posterior inference data
414+
4. the correct number of MCMC draws exists in the posterior inference data
415+
5. the method get_plot_data returns a DataFrame with expected columns
416+
"""
417+
418+
df = (
419+
cp.load_data("covid")
420+
.assign(date=lambda x: pd.to_datetime(x["date"]))
421+
.set_index("date")
422+
)
423+
result = cp.InterruptedTimeSeries(
424+
df,
425+
None,
426+
formula="standardize(deaths) ~ 0 + t + C(month) + standardize(temp)", # noqa E501
427+
model=cp.pymc_models.InterventionTimeEstimator(
428+
time_variable_name="t",
429+
treatment_type_effect={"impulse": []},
430+
sample_kwargs=sample_kwargs,
431+
),
432+
)
433+
assert isinstance(df, pd.DataFrame)
434+
assert isinstance(result, cp.InterruptedTimeSeries)
435+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
436+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
437+
result.summary()
438+
fig, ax = result.plot()
439+
assert isinstance(fig, plt.Figure)
440+
# For multi-panel plots, ax should be an array of axes
441+
assert isinstance(ax, np.ndarray) and all(
442+
isinstance(item, plt.Axes) for item in ax
443+
), "ax must be a numpy.ndarray of plt.Axes"
444+
# Test get_plot_data with default parameters
445+
plot_data = result.get_plot_data()
446+
assert isinstance(plot_data, pd.DataFrame), (
447+
"The returned object is not a pandas DataFrame"
448+
)
449+
expected_columns = [
450+
"prediction",
451+
"pred_hdi_lower_94",
452+
"pred_hdi_upper_94",
453+
"impact",
454+
"impact_hdi_lower_94",
455+
"impact_hdi_upper_94",
456+
]
457+
assert set(expected_columns).issubset(set(plot_data.columns)), (
458+
f"DataFrame is missing expected columns {expected_columns}"
459+
)
460+
461+
405462
@pytest.mark.integration
406463
def test_its_covid():
407464
"""

docs/source/_static/classes.png

-129 KB
Loading

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/_static/packages.png

-27.4 KB
Loading

docs/source/notebooks/its_no_treatment_time.ipynb

Lines changed: 59 additions & 154 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)