Skip to content

Commit 90cf94b

Browse files
authored
Merge pull request #224 from pymc-labs/rd-updates
Add an `epsilon` parameter to `RegressionDiscontinuity` classes
2 parents 21944bd + 694caab commit 90cf94b

File tree

4 files changed

+45
-21
lines changed

4 files changed

+45
-21
lines changed

causalpy/pymc_experiments.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -543,18 +543,19 @@ class RegressionDiscontinuity(ExperimentalDesign):
543543
"""
544544
A class to analyse regression discontinuity experiments.
545545
546-
:param data: A pandas dataframe
547-
:param formula: A statistical model formula
548-
:param treatment_threshold: A scalar threshold value at which the treatment
549-
is applied
550-
:param model: A PyMC model
551-
:param running_variable_name: The name of the predictor variable that the treatment
552-
threshold is based upon
553-
554-
.. note::
555-
556-
There is no pre/post intervention data distinction for the regression
557-
discontinuity design, we fit all the data available.
546+
:param data:
547+
A pandas dataframe
548+
:param formula:
549+
A statistical model formula
550+
:param treatment_threshold:
551+
A scalar threshold value at which the treatment is applied
552+
:param model:
553+
A PyMC model
554+
:param running_variable_name:
555+
The name of the predictor variable that the treatment threshold is based upon
556+
:param epsilon:
557+
A small scalar value which determines how far above and below the treatment
558+
threshold to evaluate the causal impact.
558559
"""
559560

560561
def __init__(
@@ -564,6 +565,7 @@ def __init__(
564565
treatment_threshold: float,
565566
model=None,
566567
running_variable_name: str = "x",
568+
epsilon: float = 0.001,
567569
**kwargs,
568570
):
569571
super().__init__(model=model, **kwargs)
@@ -572,6 +574,7 @@ def __init__(
572574
self.formula = formula
573575
self.running_variable_name = running_variable_name
574576
self.treatment_threshold = treatment_threshold
577+
self.epsilon = epsilon
575578
self._input_validation()
576579

577580
y, X = dmatrices(formula, self.data)
@@ -609,7 +612,10 @@ def __init__(
609612
self.x_discon = pd.DataFrame(
610613
{
611614
self.running_variable_name: np.array(
612-
[self.treatment_threshold - 0.001, self.treatment_threshold + 0.001]
615+
[
616+
self.treatment_threshold - self.epsilon,
617+
self.treatment_threshold + self.epsilon,
618+
]
613619
),
614620
"treated": np.array([0, 1]),
615621
}

causalpy/skl_experiments.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,21 @@ def plot(self):
346346

347347
class RegressionDiscontinuity(ExperimentalDesign):
348348
"""
349-
Analyse data from regression discontinuity experiments.
350-
351-
.. note::
352-
353-
There is no pre/post intervention data distinction for the regression
354-
discontinuity design, we fit all the data available.
355-
349+
A class to analyse regression discontinuity experiments.
350+
351+
:param data:
352+
A pandas dataframe
353+
:param formula:
354+
A statistical model formula
355+
:param treatment_threshold:
356+
A scalar threshold value at which the treatment is applied
357+
:param model:
358+
A scikit-learn model object
359+
:param running_variable_name:
360+
The name of the predictor variable that the treatment threshold is based upon
361+
:param epsilon:
362+
A small scalar value which determines how far above and below the treatment
363+
threshold to evaluate the causal impact.
356364
"""
357365

358366
def __init__(
@@ -362,13 +370,15 @@ def __init__(
362370
treatment_threshold,
363371
model=None,
364372
running_variable_name="x",
373+
epsilon: float = 0.001,
365374
**kwargs,
366375
):
367376
super().__init__(model=model, **kwargs)
368377
self.data = data
369378
self.formula = formula
370379
self.running_variable_name = running_variable_name
371380
self.treatment_threshold = treatment_threshold
381+
self.epsilon = epsilon
372382
y, X = dmatrices(formula, self.data)
373383
self._y_design_info = y.design_info
374384
self._x_design_info = X.design_info
@@ -404,7 +414,10 @@ def __init__(
404414
self.x_discon = pd.DataFrame(
405415
{
406416
self.running_variable_name: np.array(
407-
[self.treatment_threshold - 0.001, self.treatment_threshold + 0.001]
417+
[
418+
self.treatment_threshold - self.epsilon,
419+
self.treatment_threshold + self.epsilon,
420+
]
408421
),
409422
"treated": np.array([0, 1]),
410423
}

causalpy/tests/test_integration_pymc_examples.py

+1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_rd():
112112
formula="y ~ 1 + bs(x, df=6) + treated",
113113
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
114114
treatment_threshold=0.5,
115+
epsilon=0.001,
115116
)
116117
assert isinstance(df, pd.DataFrame)
117118
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)

causalpy/tests/test_integration_skl_examples.py

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_rd_drinking():
3636
running_variable_name="age",
3737
model=LinearRegression(),
3838
treatment_threshold=21,
39+
epsilon=0.001,
3940
)
4041
assert isinstance(df, pd.DataFrame)
4142
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
@@ -81,6 +82,7 @@ def test_rd_linear_main_effects():
8182
formula="y ~ 1 + x + treated",
8283
model=LinearRegression(),
8384
treatment_threshold=0.5,
85+
epsilon=0.001,
8486
)
8587
assert isinstance(data, pd.DataFrame)
8688
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
@@ -94,6 +96,7 @@ def test_rd_linear_with_interaction():
9496
formula="y ~ 1 + x + treated + x:treated",
9597
model=LinearRegression(),
9698
treatment_threshold=0.5,
99+
epsilon=0.001,
97100
)
98101
assert isinstance(data, pd.DataFrame)
99102
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
@@ -108,6 +111,7 @@ def test_rd_linear_with_gaussian_process():
108111
formula="y ~ 1 + x + treated",
109112
model=GaussianProcessRegressor(kernel=kernel),
110113
treatment_threshold=0.5,
114+
epsilon=0.001,
111115
)
112116
assert isinstance(data, pd.DataFrame)
113117
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)

0 commit comments

Comments
 (0)