Skip to content

Commit 3070bc9

Browse files
authored
Merge pull request #213 from NathanielF/feature_instrumental_variables
Add Bayesian instrumental variable estimation
2 parents 31e0039 + cb68c78 commit 3070bc9

14 files changed

+1239
-3
lines changed

causalpy/data/AJR2001.csv

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
longname,shortnam,logmort0,risk,loggdp,campaign,source0,slave,latitude,neoeuro,asia,africa,other,edes1975,campaignsj,campaignsj2,mortnaval1,logmortnaval1,mortnaval2,logmortnaval2,mortjam,logmortjam,logmortcap250,logmortjam250,wandcafrica,malfal94,wacacontested,mortnaval2250,logmortnaval2250,mortnaval1250,logmortnaval1250
2+
Angola,AGO,5.6347895,5.3600001,7.77,1,0,0,0.1367,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
3+
Argentina,ARG,4.232656,6.3899999,9.1300001,1,0,0,0.37779999,0,0,0,0,90.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.232656,4.0342407,0,0.0,0,30.5,3.4177268,15.07,2.7127061
4+
Australia,AUS,2.1459312,9.3199997,9.8999996,0,0,0,0.30000001,1,0,0,1,99.0,0,1,8.5500002,2.1459312,8.5500002,2.1459312,8.5500002,2.1459312,2.1459312,2.1459312,0,0.0,0,8.5500002,2.1459312,8.5500002,2.1459312
5+
Burkina Faso,BFA,5.6347895,4.4499998,6.8499999,1,0,0,0.1444,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
6+
Bangladesh,BGD,4.2684379,5.1399999,6.8800001,1,1,0,0.2667,0,1,0,0,0.0,1,1,71.410004,4.2684379,71.410004,4.2684379,71.410004,4.2684379,4.2684379,4.2684379,0,0.12008,0,71.410004,4.2684379,71.410004,4.2684379
7+
Bahamas,BHS,4.4426513,7.5,9.29,0,0,0,0.2683,0,0,0,0,10.0,0,0,85.0,4.4426513,85.0,4.4426513,85.0,4.4426513,4.4426513,4.4426513,0,,0,85.0,4.4426513,85.0,4.4426513
8+
Bolivia,BOL,4.2626801,5.6399999,7.9299998,1,0,0,0.18889999,0,0,0,0,30.000002,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.00165,0,93.25,4.535284,,
9+
Brazil,BRA,4.2626801,7.9099998,8.7299995,1,0,0,0.1111,0,0,0,0,55.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.2626801,4.0342407,0,0.035999998,0,30.5,3.4177268,15.07,2.7127061
10+
Canada,CAN,2.7788193,9.7299995,9.9899998,0,1,0,0.66670001,1,0,0,0,98.0,0,0,16.1,2.7788193,16.1,2.7788193,16.1,2.7788193,2.7788193,2.7788193,0,0.0,0,16.1,2.7788193,16.1,2.7788193
11+
Chile,CHL,4.232656,7.8200002,9.3400002,1,0,0,0.33329999,0,0,0,0,50.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.232656,4.0342407,0,0.0,0,30.5,3.4177268,15.07,2.7127061
12+
Cote d'Ivoire,CIV,6.5042882,7.0,7.4400001,1,0,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
13+
Cameroon,CMR,5.6347895,6.4499998,7.5,1,0,0,0.066699997,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
14+
Congo,COG,5.480639,4.6799998,7.4200001,0,1,1,0.0111,0,0,1,0,0.0,0,0,240.0,5.480639,240.0,5.480639,240.0,5.480639,5.480639,5.480639,1,0.94999999,0,240.0,5.480639,240.0,5.480639
15+
Colombia,COL,4.2626801,7.3200002,8.8100004,1,0,0,0.044399999,0,0,0,0,25.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.14637001,0,93.25,4.535284,,
16+
Costa Rica,CRI,4.3579903,7.0500002,8.79,1,0,0,0.1111,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
17+
Dominican Re,DOM,4.8675346,6.1799998,8.3599997,0,0,0,0.2111,0,0,0,0,25.0,0,0,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,0.0,0,130.0,4.8675346,130.0,4.8675346
18+
Algeria,DZA,4.3592696,6.5,8.3900003,1,1,0,0.31110001,0,0,1,0,0.0,1,1,78.199997,4.3592696,78.199997,4.3592696,78.199997,4.3592696,4.3592696,4.3592696,0,0.0,0,78.199997,4.3592696,78.199997,4.3592696
19+
Ecuador,ECU,4.2626801,6.5500002,8.4700003,1,0,0,0.0222,0,0,0,0,30.000002,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.11894999,0,93.25,4.535284,,
20+
Egypt,EGY,4.2165623,6.77,7.9499998,1,1,0,0.30000001,0,0,1,0,0.0,1,1,67.800003,4.2165623,67.800003,4.2165623,67.800003,4.2165623,4.2165623,4.2165623,0,0.0,0,67.800003,4.2165623,67.800003,4.2165623
21+
Ethiopia,ETH,3.2580965,5.73,6.1100001,1,1,0,0.0889,0,0,1,0,0.0,1,1,26.0,3.2580965,26.0,3.2580965,26.0,3.2580965,3.2580965,3.2580965,1,0.551,0,26.0,3.2580965,26.0,3.2580965
22+
Gabon,GAB,5.6347895,7.8200002,8.8999996,1,0,0,0.0111,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94050002,1,250.0,5.521461,250.0,5.521461
23+
Ghana,GHA,6.5042882,6.27,7.3699999,1,1,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
24+
Guinea,GIN,6.1800165,6.5500002,7.4899998,1,0,0,0.1222,0,0,1,0,0.0,1,1,483.0,6.1800165,483.0,6.1800165,483.0,6.1800165,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
25+
Gambia,GMB,7.2930179,8.2700005,7.27,1,1,0,0.1476,0,0,1,0,0.0,1,1,1470.0,7.2930179,1470.0,7.2930179,1470.0,7.2930179,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
26+
Guatemala,GTM,4.2626801,5.1399999,8.29,1,0,0,0.17,0,0,0,0,20.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.0036000002,0,93.25,4.535284,,
27+
Guyana,GUY,3.4713452,5.8899999,7.9000001,0,0,0,0.055599999,0,0,0,0,2.0,0,0,32.18,3.4713452,32.18,3.4713452,32.18,3.4713452,3.4713452,3.4713452,0,0.49503002,0,32.18,3.4713452,32.18,3.4713452
28+
Hong Kong,HKG,2.7013612,8.1400003,10.05,0,0,0,0.24609999,0,1,0,0,0.0,1,1,14.9,2.7013612,14.9,2.7013612,14.9,2.7013612,2.7013612,2.7013612,0,0.0,0,14.9,2.7013612,14.9,2.7013612
29+
Honduras,HND,4.3579903,5.3200002,7.6900001,1,0,0,0.16670001,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.012,0,93.25,4.535284,,
30+
Haiti,HTI,4.8675346,3.73,7.1500001,0,0,0,0.2111,0,0,0,0,0.0,0,0,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,1.0,0,130.0,4.8675346,130.0,4.8675346
31+
Indonesia,IDN,5.1357985,7.5900002,7.3299999,1,1,0,0.055599999,0,1,0,0,0.0,1,1,170.0,5.1357985,170.0,5.1357985,170.0,5.1357985,5.1357985,5.1357985,0,0.17873,0,170.0,5.1357985,170.0,5.1357985
32+
India,IND,3.8842406,8.2700005,7.3299999,0,1,0,0.22220001,0,1,0,0,0.0,0,0,48.630001,3.8842406,48.630001,3.8842406,48.630001,3.8842406,3.8842406,3.8842406,0,0.23596001,0,48.630001,3.8842406,48.630001,3.8842406
33+
Jamaica,JAM,4.8675346,7.0900002,8.1899996,0,1,0,0.2017,0,0,0,0,10.0,0,1,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,0.0,0,130.0,4.8675346,130.0,4.8675346
34+
Kenya,KEN,4.9767337,6.0500002,7.0599999,0,1,1,0.0111,0,0,1,0,0.0,0,0,145.0,4.9767337,145.0,4.9767337,145.0,4.9767337,4.9767337,4.9767337,1,0.79799998,0,145.0,4.9767337,145.0,4.9767337
35+
Sri Lanka,LKA,4.2456341,6.0500002,7.73,0,1,0,0.077799998,0,1,0,0,0.0,0,1,69.800003,4.2456341,69.800003,4.2456341,69.800003,4.2456341,4.2456341,4.2456341,0,0.138,0,69.800003,4.2456341,69.800003,4.2456341
36+
Morocco,MAR,4.3592696,7.0900002,8.04,1,0,0,0.3556,0,0,1,0,1.0,1,1,78.199997,4.3592696,78.199997,4.3592696,78.199997,4.3592696,4.3592696,4.3592696,0,0.0,0,78.199997,4.3592696,78.199997,4.3592696
37+
Madagascar,MDG,6.2842088,4.4499998,6.8400002,1,1,0,0.22220001,0,0,1,0,0.0,1,1,536.03998,6.2842088,536.03998,6.2842088,536.03998,6.2842088,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
38+
Mexico,MEX,4.2626801,7.5,8.9399996,1,1,0,0.25560001,0,0,0,0,15.000001,1,1,71.0,4.2626801,71.0,4.2626801,71.0,4.2626801,4.2626801,4.2626801,0,0.00042,0,71.0,4.2626801,71.0,4.2626801
39+
Mali,MLI,7.986165,4.0,6.5700002,1,1,0,0.18889999,0,0,1,0,0.0,1,1,2940.0,7.986165,2940.0,7.986165,2940.0,7.986165,5.521461,5.521461,1,0.94050002,0,250.0,5.521461,250.0,5.521461
40+
Malta,MLT,2.7911651,7.23,9.4300003,0,1,0,0.3944,0,0,0,1,100.0,0,0,16.299999,2.7911651,16.299999,2.7911651,16.299999,2.7911651,2.7911651,2.7911651,0,,0,16.299999,2.7911651,16.299999,2.7911651
41+
Malaysia,MYS,2.8735647,7.9499998,8.8900003,0,1,0,0.025599999,0,1,0,0,0.0,0,1,17.700001,2.8735647,17.700001,2.8735647,17.700001,2.8735647,2.8735647,2.8735647,0,0.23331,0,17.700001,2.8735647,17.700001,2.8735647
42+
Niger,NER,5.9914646,5.0,6.73,1,0,0,0.1778,0,0,1,0,0.0,1,1,400.0,5.9914646,400.0,5.9914646,400.0,5.9914646,5.521461,5.521461,1,0.94050002,1,250.0,5.521461,250.0,5.521461
43+
Nigeria,NGA,7.6029005,5.5500002,6.8099999,1,1,0,0.1111,0,0,1,0,0.0,1,1,2004.0,7.6029005,2004.0,7.6029005,2004.0,7.6029005,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
44+
Nicaragua,NIC,5.0955892,5.23,7.54,1,0,0,0.1444,0,0,0,0,20.0,1,1,,,93.25,4.535284,130.0,4.8675346,5.0955892,4.8675346,0,0.044,0,93.25,4.535284,,
45+
New Zealand,NZL,2.1459312,9.7299995,9.7600002,0,1,0,0.45559999,1,0,0,1,91.699997,1,1,8.5500002,2.1459312,8.5500002,2.1459312,8.5500002,2.1459312,2.1459312,2.1459312,0,0.0,0,8.5500002,2.1459312,8.5500002,2.1459312
46+
Pakistan,PAK,3.6106477,6.0500002,7.3499999,1,0,0,0.33329999,0,1,0,0,0.0,1,1,36.990002,3.6106477,36.990002,3.6106477,36.990002,3.6106477,3.6106477,3.6106477,0,0.53757,0,36.990002,3.6106477,36.990002,3.6106477
47+
Panama,PAN,5.0955892,5.9099998,8.8400002,1,0,0,0.1,0,0,0,0,20.0,1,1,15.07,2.7127061,30.5,3.4177268,130.0,4.8675346,5.0955892,4.8675346,0,0.08004,0,30.5,3.4177268,15.07,2.7127061
48+
Peru,PER,4.2626801,5.77,8.3999996,1,0,0,0.1111,0,0,0,0,30.000002,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.2626801,4.0342407,0,0.00050000002,0,30.5,3.4177268,15.07,2.7127061
49+
Paraguay,PRY,4.3579903,6.9499998,8.21,1,0,0,0.25560001,0,0,0,0,25.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
50+
Sudan,SDN,4.4796071,4.0,7.3099999,1,1,0,0.16670001,0,0,1,0,0.0,1,1,88.199997,4.4796071,88.199997,4.4796071,88.199997,4.4796071,4.4796071,4.4796071,1,0.93099999,0,88.199997,4.4796071,88.199997,4.4796071
51+
Senegal,SEN,5.1038828,6.0,7.4000001,0,1,0,0.1556,0,0,1,0,0.0,0,1,164.66,5.1038828,164.66,5.1038828,164.66,5.1038828,5.1038828,5.1038828,1,0.94999999,0,164.66,5.1038828,164.66,5.1038828
52+
Singapore,SGP,2.8735647,9.3199997,10.15,0,0,0,0.0136,0,1,0,0,0.0,0,1,17.700001,2.8735647,17.700001,2.8735647,17.700001,2.8735647,2.8735647,2.8735647,0,0.0,0,17.700001,2.8735647,17.700001,2.8735647
53+
Sierra Leone,SLE,6.1800165,5.8200002,6.25,1,1,0,0.092200004,0,0,1,0,0.0,1,1,483.0,6.1800165,483.0,6.1800165,483.0,6.1800165,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
54+
El Salvador,SLV,4.3579903,5.0,7.9499998,1,0,0,0.15000001,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
55+
Togo,TGO,6.5042882,6.9099998,7.2199998,1,0,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
56+
Trinidad and Tobago,TTO,4.4426513,7.4499998,8.7700005,0,1,0,0.1222,0,0,0,0,40.0,0,1,85.0,4.4426513,85.0,4.4426513,85.0,4.4426513,4.4426513,4.4426513,0,0.0,0,85.0,4.4426513,85.0,4.4426513
57+
Tunisia,TUN,4.1431346,6.4499998,8.4799995,1,1,0,0.37779999,0,0,1,0,0.0,1,1,63.0,4.1431346,63.0,4.1431346,63.0,4.1431346,4.1431346,4.1431346,0,0.0,0,63.0,4.1431346,63.0,4.1431346
58+
Tanzania,TZA,4.9767337,6.6399999,6.25,0,0,1,0.066699997,0,0,1,0,0.0,0,0,145.0,4.9767337,145.0,4.9767337,145.0,4.9767337,4.9767337,4.9767337,1,0.92150003,1,145.0,4.9767337,145.0,4.9767337
59+
Uganda,UGA,5.6347895,4.4499998,6.9699998,1,0,0,0.0111,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
60+
Uruguary,URY,4.2626801,7.0,9.0299997,1,0,0,0.36669999,0,0,0,0,90.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.0,0,93.25,4.535284,,
61+
USA,USA,2.7080503,10.0,10.22,0,1,0,0.42219999,1,0,0,0,83.600006,0,1,15.0,2.7080503,15.0,2.7080503,15.0,2.7080503,2.7080503,2.7080503,0,0.0,0,15.0,2.7080503,15.0,2.7080503
62+
Venezuela,VEN,4.3579903,7.1399999,9.0699997,1,0,0,0.0889,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0070400001,0,93.25,4.535284,,
63+
Vietnam,VNM,4.9416423,6.4099998,7.2800002,1,1,0,0.1778,0,1,0,0,0.0,1,1,140.0,4.9416423,140.0,4.9416423,140.0,4.9416423,4.9416423,4.9416423,0,0.70109999,0,140.0,4.9416423,140.0,4.9416423
64+
South Africa,ZAF,2.74084,6.8600001,8.8900003,0,1,0,0.3222,0,0,1,0,16.0,0,1,15.5,2.74084,15.5,2.74084,15.5,2.74084,2.74084,2.74084,0,0.1045,0,15.5,2.74084,15.5,2.74084
65+
Zaire,ZAR,5.480639,3.5,6.8699999,0,0,1,0.0,0,0,1,0,0.0,0,0,240.0,5.480639,240.0,5.480639,240.0,5.480639,5.480639,5.480639,1,0.94999999,1,240.0,5.480639,240.0,5.480639

causalpy/data/datasets.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"sc": {"filename": "synthetic_control.csv"},
1717
"anova1": {"filename": "ancova_generated.csv"},
1818
"geolift1": {"filename": "geolift1.csv"},
19+
"risk": {"filename": "AJR2001.csv"},
1920
}
2021

2122

causalpy/pymc_experiments.py

+128
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import seaborn as sns
99
import xarray as xr
1010
from patsy import build_design_matrices, dmatrices
11+
from sklearn.linear_model import LinearRegression as sk_lin_reg
1112

1213
from causalpy.custom_exceptions import BadIndexException # NOQA
1314
from causalpy.custom_exceptions import DataException, FormulaException
@@ -883,3 +884,130 @@ def _get_treatment_effect_coeff(self) -> str:
883884
return label
884885

885886
raise NameError("Unable to find coefficient name for the treatment effect")
887+
888+
889+
class InstrumentalVariable(ExperimentalDesign):
890+
"""
891+
A class to analyse instrumental variable style experiments.
892+
893+
:param instruments_data: A pandas dataframe of instruments
894+
for our treatment variable. Should contain
895+
instruments Z, and treatment t
896+
:param data: A pandas dataframe of covariates for fitting
897+
the focal regression of interest. Should contain covariates X
898+
including treatment t and outcome y
899+
:param instruments_formula: A statistical model formula for
900+
the instrumental stage regression
901+
e.g. t ~ 1 + z1 + z2 + z3
902+
:param formula: A statistical model formula for the \n
903+
focal regression e.g. y ~ 1 + t + x1 + x2 + x3
904+
:param model: A PyMC model
905+
:param priors: An optional dictionary of priors for the
906+
mus and sigmas of both regressions. If priors are not
907+
specified we will substitue MLE estimates for the beta
908+
coefficients. Greater control can be achieved
909+
by specifying the priors directly e.g. priors = {
910+
"mus": [0, 0],
911+
"sigmas": [1, 1],
912+
"eta": 2,
913+
"lkj_sd": 2,
914+
}
915+
916+
"""
917+
918+
def __init__(
919+
self,
920+
instruments_data: pd.DataFrame,
921+
data: pd.DataFrame,
922+
instruments_formula: str,
923+
formula: str,
924+
model=None,
925+
priors=None,
926+
**kwargs,
927+
):
928+
super().__init__(model=model, **kwargs)
929+
self.expt_type = "Instrumental Variable Regression"
930+
self.data = data
931+
self.instruments_data = instruments_data
932+
self.formula = formula
933+
self.instruments_formula = instruments_formula
934+
self.model = model
935+
self._input_validation()
936+
937+
y, X = dmatrices(formula, self.data)
938+
self._y_design_info = y.design_info
939+
self._x_design_info = X.design_info
940+
self.labels = X.design_info.column_names
941+
self.y, self.X = np.asarray(y), np.asarray(X)
942+
self.outcome_variable_name = y.design_info.column_names[0]
943+
944+
t, Z = dmatrices(instruments_formula, self.instruments_data)
945+
self._t_design_info = t.design_info
946+
self._z_design_info = Z.design_info
947+
self.labels_instruments = Z.design_info.column_names
948+
self.t, self.Z = np.asarray(t), np.asarray(Z)
949+
self.instrument_variable_name = t.design_info.column_names[0]
950+
951+
self.get_naive_OLS_fit()
952+
self.get_2SLS_fit()
953+
954+
# fit the model to the data
955+
COORDS = {"instruments": self.labels_instruments, "covariates": self.labels}
956+
self.coords = COORDS
957+
if priors is None:
958+
priors = {
959+
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
960+
"sigmas": [1, 1],
961+
"eta": 2,
962+
"lkj_sd": 2,
963+
}
964+
self.priors = priors
965+
self.model.fit(
966+
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
967+
)
968+
969+
def get_2SLS_fit(self):
970+
first_stage_reg = sk_lin_reg().fit(self.Z, self.t)
971+
fitted_Z_values = first_stage_reg.predict(self.Z)
972+
X2 = self.data.copy(deep=True)
973+
X2[self.instrument_variable_name] = fitted_Z_values
974+
_, X2 = dmatrices(self.formula, X2)
975+
second_stage_reg = sk_lin_reg().fit(X=X2, y=self.y)
976+
betas_first = list(first_stage_reg.coef_[0][1:])
977+
betas_first.insert(0, first_stage_reg.intercept_[0])
978+
betas_second = list(second_stage_reg.coef_[0][1:])
979+
betas_second.insert(0, second_stage_reg.intercept_[0])
980+
self.ols_beta_first_params = betas_first
981+
self.ols_beta_second_params = betas_second
982+
self.first_stage_reg = first_stage_reg
983+
self.second_stage_reg = second_stage_reg
984+
985+
def get_naive_OLS_fit(self):
986+
ols_reg = sk_lin_reg().fit(self.X, self.y)
987+
beta_params = list(ols_reg.coef_[0][1:])
988+
beta_params.insert(0, ols_reg.intercept_[0])
989+
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
990+
self.ols_reg = ols_reg
991+
992+
def _input_validation(self):
993+
"""Validate the input data and model formula for correctness"""
994+
treatment = self.instruments_formula.split("~")[0]
995+
test = treatment.strip() in self.instruments_data.columns
996+
test = test & (treatment.strip() in self.data.columns)
997+
if not test:
998+
raise DataException(
999+
f"""
1000+
The treatment variable:
1001+
{treatment} must appear in the instrument_data to be used
1002+
as an outcome variable and in the data object to be used as a covariate.
1003+
"""
1004+
)
1005+
Z = self.data[treatment.strip()]
1006+
check_binary = len(np.unique(Z)) > 2
1007+
if check_binary:
1008+
warnings.warn(
1009+
"""Warning. The treatment variable is not Binary.
1010+
This is not necessarily a problem but it violates
1011+
the assumption of a simple IV experiment.
1012+
The coefficients should be interpreted appropriately."""
1013+
)

0 commit comments

Comments
 (0)