Skip to content

Commit 04584d0

Browse files
committed
[IV 212] adding user warning and tidying params
Signed-off-by: Nathaniel <[email protected]>
1 parent f08b02d commit 04584d0

File tree

2 files changed

+134
-95
lines changed

2 files changed

+134
-95
lines changed

causalpy/pymc_experiments.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Union
23

34
import arviz as az
@@ -864,22 +865,28 @@ class InstrumentalVariable(ExperimentalDesign):
864865
A class to analyse instrumental variable style experiments.
865866
866867
:param instruments_data: A pandas dataframe of instruments
867-
for our treatment variable
868+
for our treatment variable. Should contain
869+
instruments Z, and treatment t
868870
:param data: A pandas dataframe of covariates for fitting
869-
the focal regression of interest
871+
the focal regression of interest. Should contain covariates X
872+
including treatment t and outcome y
870873
:param instruments_formula: A statistical model formula for
871874
the instrumental stage regression
872-
:param formula: A statistical model formula for the focal regression
875+
e.g. t ~ 1 + z1 + z2 + z3
876+
:param formula: A statistical model formula for the \n
877+
focal regression e.g. y ~ 1 + t + x1 + x2 + x3
873878
:param model: A PyMC model
874879
:param priors: An optional dictionary of priors for the
875880
mus and sigmas of both regressions. If priors are not
876881
specified we will substitue MLE estimates for the beta
877-
coefficients
882+
coefficients. Greater control can be achieved
883+
by specifying the priors directly e.g. priors = {
884+
"mus": [0, 0],
885+
"sigmas": [1, 1],
886+
"eta": 2,
887+
"lkj_sd": 2,
888+
}
878889
879-
.. note::
880-
881-
There is no pre/post intervention data distinction for the instrumental variable
882-
design, we fit all the data available.
883890
"""
884891

885892
def __init__(
@@ -969,3 +976,10 @@ def _input_validation(self):
969976
as an outcome variable and in the data object to be used as a covariate.
970977
"""
971978
)
979+
check_binary = len(self.data[treatment.strip()].unique()) > 2
980+
if check_binary:
981+
warnings.warn(
982+
"""Warning. The treatment variable is not Binary. \n
983+
This is not necessarily a problem but it complicates \n
984+
the interpretation of the model coefficients."""
985+
)

docs/source/notebooks/iv_pymc.ipynb

+112-87
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)