Skip to content

Commit f08b02d

Browse files
committed
[IV 212] adding an input validation test
1 parent 7a28dab commit f08b02d

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

causalpy/pymc_experiments.py

+15
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,7 @@ def __init__(
899899
self.formula = formula
900900
self.instruments_formula = instruments_formula
901901
self.model = model
902+
self._input_validation()
902903

903904
y, X = dmatrices(formula, self.data)
904905
self._y_design_info = y.design_info
@@ -954,3 +955,17 @@ def get_naive_OLS_fit(self):
954955
beta_params.insert(0, ols_reg.intercept_[0])
955956
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
956957
self.ols_reg = ols_reg
958+
959+
def _input_validation(self):
960+
"""Validate the input data and model formula for correctness"""
961+
treatment = self.instruments_formula.split("~")[0]
962+
test = treatment.strip() in self.instruments_data.columns
963+
test = test & (treatment.strip() in self.data.columns)
964+
if not test:
965+
raise DataException(
966+
f"""
967+
The treatment variable:
968+
{treatment} must appear in the instrument_data to be used
969+
as an outcome variable and in the data object to be used as a covariate.
970+
"""
971+
)

causalpy/tests/test_input_validation.py

+18
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,21 @@ def test_rd_validation_treated_is_dummy():
202202
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
203203
treatment_threshold=0.5,
204204
)
205+
206+
207+
def test_iv_treatment_var_is_present():
208+
data = pd.DataFrame({"x": [1, 2, 3], "y": [2, 4, 5]})
209+
instruments_formula = "risk ~ 1 + logmort0"
210+
formula = "loggdp ~ 1 + risk"
211+
instruments_data = pd.DataFrame({"z": [1, 3, 4], "w": [2, 3, 4]})
212+
213+
with pytest.raises(DataException):
214+
_ = cp.pymc_experiments.InstrumentalVariable(
215+
instruments_data=instruments_data,
216+
data=data,
217+
instruments_formula=instruments_formula,
218+
formula=formula,
219+
model=cp.pymc_models.InstrumentalVariableRegression(
220+
sample_kwargs=sample_kwargs
221+
),
222+
)

docs/source/_static/interrogate_badge.svg

+3-3
Loading

0 commit comments

Comments
 (0)