Skip to content

Commit efa3f6d

Browse files
authored
Merge branch 'main' into feature_instrumental_variables
2 parents 04584d0 + 31e0039 commit efa3f6d

8 files changed

+937
-427
lines changed

causalpy/pymc_experiments.py

+48-22
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Union
2+
from typing import Optional, Union
33

44
import arviz as az
55
import matplotlib.pyplot as plt
@@ -543,20 +543,24 @@ def summary(self):
543543

544544
class RegressionDiscontinuity(ExperimentalDesign):
545545
"""
546-
A class to analyse regression discontinuity experiments.
547-
548-
:param data: A pandas dataframe
549-
:param formula: A statistical model formula
550-
:param treatment_threshold: A scalar threshold value at which the treatment
551-
is applied
552-
:param model: A PyMC model
553-
:param running_variable_name: The name of the predictor variable that the treatment
554-
threshold is based upon
555-
556-
.. note::
557-
558-
There is no pre/post intervention data distinction for the regression
559-
discontinuity design, we fit all the data available.
546+
A class to analyse sharp regression discontinuity experiments.
547+
548+
:param data:
549+
A pandas dataframe
550+
:param formula:
551+
A statistical model formula
552+
:param treatment_threshold:
553+
A scalar threshold value at which the treatment is applied
554+
:param model:
555+
A PyMC model
556+
:param running_variable_name:
557+
The name of the predictor variable that the treatment threshold is based upon
558+
:param epsilon:
559+
A small scalar value which determines how far above and below the treatment
560+
threshold to evaluate the causal impact.
561+
:param bandwidth:
562+
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
563+
the model.
560564
"""
561565

562566
def __init__(
@@ -566,6 +570,8 @@ def __init__(
566570
treatment_threshold: float,
567571
model=None,
568572
running_variable_name: str = "x",
573+
epsilon: float = 0.001,
574+
bandwidth: Optional[float] = None,
569575
**kwargs,
570576
):
571577
super().__init__(model=model, **kwargs)
@@ -574,9 +580,23 @@ def __init__(
574580
self.formula = formula
575581
self.running_variable_name = running_variable_name
576582
self.treatment_threshold = treatment_threshold
583+
self.epsilon = epsilon
584+
self.bandwidth = bandwidth
577585
self._input_validation()
578586

579-
y, X = dmatrices(formula, self.data)
587+
if self.bandwidth is not None:
588+
fmin = self.treatment_threshold - self.bandwidth
589+
fmax = self.treatment_threshold + self.bandwidth
590+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
591+
if len(filtered_data) <= 10:
592+
warnings.warn(
593+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
594+
UserWarning,
595+
)
596+
y, X = dmatrices(formula, filtered_data)
597+
else:
598+
y, X = dmatrices(formula, self.data)
599+
580600
self._y_design_info = y.design_info
581601
self._x_design_info = X.design_info
582602
self.labels = X.design_info.column_names
@@ -593,11 +613,14 @@ def __init__(
593613
self.score = self.model.score(X=self.X, y=self.y)
594614

595615
# get the model predictions of the observed data
596-
xi = np.linspace(
597-
np.min(self.data[self.running_variable_name]),
598-
np.max(self.data[self.running_variable_name]),
599-
200,
600-
)
616+
if self.bandwidth is not None:
617+
xi = np.linspace(fmin, fmax, 200)
618+
else:
619+
xi = np.linspace(
620+
np.min(self.data[self.running_variable_name]),
621+
np.max(self.data[self.running_variable_name]),
622+
200,
623+
)
601624
self.x_pred = pd.DataFrame(
602625
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
603626
)
@@ -611,7 +634,10 @@ def __init__(
611634
self.x_discon = pd.DataFrame(
612635
{
613636
self.running_variable_name: np.array(
614-
[self.treatment_threshold - 0.001, self.treatment_threshold + 0.001]
637+
[
638+
self.treatment_threshold - self.epsilon,
639+
self.treatment_threshold + self.epsilon,
640+
]
615641
),
616642
"treated": np.array([0, 1]),
617643
}

causalpy/skl_experiments.py

+51-14
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
from typing import Optional
3+
14
import matplotlib.pyplot as plt
25
import numpy as np
36
import pandas as pd
@@ -346,13 +349,24 @@ def plot(self):
346349

347350
class RegressionDiscontinuity(ExperimentalDesign):
348351
"""
349-
Analyse data from regression discontinuity experiments.
350-
351-
.. note::
352-
353-
There is no pre/post intervention data distinction for the regression
354-
discontinuity design, we fit all the data available.
355-
352+
A class to analyse sharp regression discontinuity experiments.
353+
354+
:param data:
355+
A pandas dataframe
356+
:param formula:
357+
A statistical model formula
358+
:param treatment_threshold:
359+
A scalar threshold value at which the treatment is applied
360+
:param model:
361+
A scikit-learn model object
362+
:param running_variable_name:
363+
The name of the predictor variable that the treatment threshold is based upon
364+
:param epsilon:
365+
A small scalar value which determines how far above and below the treatment
366+
threshold to evaluate the causal impact.
367+
:param bandwidth:
368+
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
369+
the model.
356370
"""
357371

358372
def __init__(
@@ -362,14 +376,31 @@ def __init__(
362376
treatment_threshold,
363377
model=None,
364378
running_variable_name="x",
379+
epsilon: float = 0.001,
380+
bandwidth: Optional[float] = None,
365381
**kwargs,
366382
):
367383
super().__init__(model=model, **kwargs)
368384
self.data = data
369385
self.formula = formula
370386
self.running_variable_name = running_variable_name
371387
self.treatment_threshold = treatment_threshold
372-
y, X = dmatrices(formula, self.data)
388+
self.bandwidth = bandwidth
389+
self.epsilon = epsilon
390+
391+
if self.bandwidth is not None:
392+
fmin = self.treatment_threshold - self.bandwidth
393+
fmax = self.treatment_threshold + self.bandwidth
394+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
395+
if len(filtered_data) <= 10:
396+
warnings.warn(
397+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
398+
UserWarning,
399+
)
400+
y, X = dmatrices(formula, filtered_data)
401+
else:
402+
y, X = dmatrices(formula, self.data)
403+
373404
self._y_design_info = y.design_info
374405
self._x_design_info = X.design_info
375406
self.labels = X.design_info.column_names
@@ -386,11 +417,14 @@ def __init__(
386417
self.score = self.model.score(X=self.X, y=self.y)
387418

388419
# get the model predictions of the observed data
389-
xi = np.linspace(
390-
np.min(self.data[self.running_variable_name]),
391-
np.max(self.data[self.running_variable_name]),
392-
1000,
393-
)
420+
if self.bandwidth is not None:
421+
xi = np.linspace(fmin, fmax, 200)
422+
else:
423+
xi = np.linspace(
424+
np.min(self.data[self.running_variable_name]),
425+
np.max(self.data[self.running_variable_name]),
426+
200,
427+
)
394428
self.x_pred = pd.DataFrame(
395429
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
396430
)
@@ -404,7 +438,10 @@ def __init__(
404438
self.x_discon = pd.DataFrame(
405439
{
406440
self.running_variable_name: np.array(
407-
[self.treatment_threshold - 0.001, self.treatment_threshold + 0.001]
441+
[
442+
self.treatment_threshold - self.epsilon,
443+
self.treatment_threshold + self.epsilon,
444+
]
408445
),
409446
"treated": np.array([0, 1]),
410447
}

causalpy/tests/test_integration_pymc_examples.py

+18
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,24 @@ def test_rd():
112112
formula="y ~ 1 + bs(x, df=6) + treated",
113113
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
114114
treatment_threshold=0.5,
115+
epsilon=0.001,
116+
)
117+
assert isinstance(df, pd.DataFrame)
118+
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
119+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
120+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
121+
122+
123+
@pytest.mark.integration
124+
def test_rd_bandwidth():
125+
df = cp.load_data("rd")
126+
result = cp.pymc_experiments.RegressionDiscontinuity(
127+
df,
128+
formula="y ~ 1 + x + treated + x:treated",
129+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
130+
treatment_threshold=0.5,
131+
epsilon=0.001,
132+
bandwidth=0.3,
115133
)
116134
assert isinstance(df, pd.DataFrame)
117135
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)

causalpy/tests/test_integration_skl_examples.py

+19
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_rd_drinking():
3636
running_variable_name="age",
3737
model=LinearRegression(),
3838
treatment_threshold=21,
39+
epsilon=0.001,
3940
)
4041
assert isinstance(df, pd.DataFrame)
4142
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
@@ -81,6 +82,22 @@ def test_rd_linear_main_effects():
8182
formula="y ~ 1 + x + treated",
8283
model=LinearRegression(),
8384
treatment_threshold=0.5,
85+
epsilon=0.001,
86+
)
87+
assert isinstance(data, pd.DataFrame)
88+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
89+
90+
91+
@pytest.mark.integration
92+
def test_rd_linear_main_effects_bandwidth():
93+
data = cp.load_data("rd")
94+
result = cp.skl_experiments.RegressionDiscontinuity(
95+
data,
96+
formula="y ~ 1 + x + treated",
97+
model=LinearRegression(),
98+
treatment_threshold=0.5,
99+
epsilon=0.001,
100+
bandwidth=0.3,
84101
)
85102
assert isinstance(data, pd.DataFrame)
86103
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
@@ -94,6 +111,7 @@ def test_rd_linear_with_interaction():
94111
formula="y ~ 1 + x + treated + x:treated",
95112
model=LinearRegression(),
96113
treatment_threshold=0.5,
114+
epsilon=0.001,
97115
)
98116
assert isinstance(data, pd.DataFrame)
99117
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
@@ -108,6 +126,7 @@ def test_rd_linear_with_gaussian_process():
108126
formula="y ~ 1 + x + treated",
109127
model=GaussianProcessRegressor(kernel=kernel),
110128
treatment_threshold=0.5,
129+
epsilon=0.001,
111130
)
112131
assert isinstance(data, pd.DataFrame)
113132
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)

docs/source/_static/interrogate_badge.svg

+3-3
Loading

0 commit comments

Comments
 (0)