Skip to content

Commit 94895b4

Browse files
authored
Merge pull request #226 from pymc-labs/rd-bandwidth
Regression Discontinuity: add ability to specify bandwidth
2 parents 90cf94b + cf64ac2 commit 94895b4

File tree

7 files changed

+722
-246
lines changed

7 files changed

+722
-246
lines changed

causalpy/pymc_experiments.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Union
1+
import warnings
2+
from typing import Optional, Union
23

34
import arviz as az
45
import matplotlib.pyplot as plt
@@ -541,7 +542,7 @@ def summary(self):
541542

542543
class RegressionDiscontinuity(ExperimentalDesign):
543544
"""
544-
A class to analyse regression discontinuity experiments.
545+
A class to analyse sharp regression discontinuity experiments.
545546
546547
:param data:
547548
A pandas dataframe
@@ -556,6 +557,9 @@ class RegressionDiscontinuity(ExperimentalDesign):
556557
:param epsilon:
557558
A small scalar value which determines how far above and below the treatment
558559
threshold to evaluate the causal impact.
560+
:param bandwidth:
561+
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
562+
the model.
559563
"""
560564

561565
def __init__(
@@ -566,6 +570,7 @@ def __init__(
566570
model=None,
567571
running_variable_name: str = "x",
568572
epsilon: float = 0.001,
573+
bandwidth: Optional[float] = None,
569574
**kwargs,
570575
):
571576
super().__init__(model=model, **kwargs)
@@ -575,9 +580,22 @@ def __init__(
575580
self.running_variable_name = running_variable_name
576581
self.treatment_threshold = treatment_threshold
577582
self.epsilon = epsilon
583+
self.bandwidth = bandwidth
578584
self._input_validation()
579585

580-
y, X = dmatrices(formula, self.data)
586+
if self.bandwidth is not None:
587+
fmin = self.treatment_threshold - self.bandwidth
588+
fmax = self.treatment_threshold + self.bandwidth
589+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
590+
if len(filtered_data) <= 10:
591+
warnings.warn(
592+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
593+
UserWarning,
594+
)
595+
y, X = dmatrices(formula, filtered_data)
596+
else:
597+
y, X = dmatrices(formula, self.data)
598+
581599
self._y_design_info = y.design_info
582600
self._x_design_info = X.design_info
583601
self.labels = X.design_info.column_names
@@ -594,11 +612,14 @@ def __init__(
594612
self.score = self.model.score(X=self.X, y=self.y)
595613

596614
# get the model predictions of the observed data
597-
xi = np.linspace(
598-
np.min(self.data[self.running_variable_name]),
599-
np.max(self.data[self.running_variable_name]),
600-
200,
601-
)
615+
if self.bandwidth is not None:
616+
xi = np.linspace(fmin, fmax, 200)
617+
else:
618+
xi = np.linspace(
619+
np.min(self.data[self.running_variable_name]),
620+
np.max(self.data[self.running_variable_name]),
621+
200,
622+
)
602623
self.x_pred = pd.DataFrame(
603624
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
604625
)

causalpy/skl_experiments.py

+31-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import warnings
2+
from typing import Optional
3+
14
import matplotlib.pyplot as plt
25
import numpy as np
36
import pandas as pd
@@ -346,7 +349,7 @@ def plot(self):
346349

347350
class RegressionDiscontinuity(ExperimentalDesign):
348351
"""
349-
A class to analyse regression discontinuity experiments.
352+
A class to analyse sharp regression discontinuity experiments.
350353
351354
:param data:
352355
A pandas dataframe
@@ -361,6 +364,9 @@ class RegressionDiscontinuity(ExperimentalDesign):
361364
:param epsilon:
362365
A small scalar value which determines how far above and below the treatment
363366
threshold to evaluate the causal impact.
367+
:param bandwidth:
368+
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
369+
the model.
364370
"""
365371

366372
def __init__(
@@ -371,15 +377,30 @@ def __init__(
371377
model=None,
372378
running_variable_name="x",
373379
epsilon: float = 0.001,
380+
bandwidth: Optional[float] = None,
374381
**kwargs,
375382
):
376383
super().__init__(model=model, **kwargs)
377384
self.data = data
378385
self.formula = formula
379386
self.running_variable_name = running_variable_name
380387
self.treatment_threshold = treatment_threshold
388+
self.bandwidth = bandwidth
381389
self.epsilon = epsilon
382-
y, X = dmatrices(formula, self.data)
390+
391+
if self.bandwidth is not None:
392+
fmin = self.treatment_threshold - self.bandwidth
393+
fmax = self.treatment_threshold + self.bandwidth
394+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
395+
if len(filtered_data) <= 10:
396+
warnings.warn(
397+
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
398+
UserWarning,
399+
)
400+
y, X = dmatrices(formula, filtered_data)
401+
else:
402+
y, X = dmatrices(formula, self.data)
403+
383404
self._y_design_info = y.design_info
384405
self._x_design_info = X.design_info
385406
self.labels = X.design_info.column_names
@@ -396,11 +417,14 @@ def __init__(
396417
self.score = self.model.score(X=self.X, y=self.y)
397418

398419
# get the model predictions of the observed data
399-
xi = np.linspace(
400-
np.min(self.data[self.running_variable_name]),
401-
np.max(self.data[self.running_variable_name]),
402-
1000,
403-
)
420+
if self.bandwidth is not None:
421+
xi = np.linspace(fmin, fmax, 200)
422+
else:
423+
xi = np.linspace(
424+
np.min(self.data[self.running_variable_name]),
425+
np.max(self.data[self.running_variable_name]),
426+
200,
427+
)
404428
self.x_pred = pd.DataFrame(
405429
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
406430
)

causalpy/tests/test_integration_pymc_examples.py

+17
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,23 @@ def test_rd():
120120
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
121121

122122

123+
@pytest.mark.integration
124+
def test_rd_bandwidth():
125+
df = cp.load_data("rd")
126+
result = cp.pymc_experiments.RegressionDiscontinuity(
127+
df,
128+
formula="y ~ 1 + x + treated + x:treated",
129+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
130+
treatment_threshold=0.5,
131+
epsilon=0.001,
132+
bandwidth=0.3,
133+
)
134+
assert isinstance(df, pd.DataFrame)
135+
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
136+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
137+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
138+
139+
123140
@pytest.mark.integration
124141
def test_rd_drinking():
125142
df = (

causalpy/tests/test_integration_skl_examples.py

+15
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ def test_rd_linear_main_effects():
8888
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
8989

9090

91+
@pytest.mark.integration
92+
def test_rd_linear_main_effects_bandwidth():
93+
data = cp.load_data("rd")
94+
result = cp.skl_experiments.RegressionDiscontinuity(
95+
data,
96+
formula="y ~ 1 + x + treated",
97+
model=LinearRegression(),
98+
treatment_threshold=0.5,
99+
epsilon=0.001,
100+
bandwidth=0.3,
101+
)
102+
assert isinstance(data, pd.DataFrame)
103+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
104+
105+
91106
@pytest.mark.integration
92107
def test_rd_linear_with_interaction():
93108
data = cp.load_data("rd")

docs/source/_static/interrogate_badge.svg

+3-3
Loading

docs/source/notebooks/rd_pymc.ipynb

+538-218
Large diffs are not rendered by default.

docs/source/notebooks/rd_skl.ipynb

+89-10
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)