13
13
probit_fit )
14
14
from samplers import (normal_sampler ,
15
15
split_sampler )
16
+ from learners import mixture_learner
16
17
17
18
def infer_general_target (algorithm ,
18
19
observed_outcome ,
@@ -25,7 +26,8 @@ def infer_general_target(algorithm,
25
26
hypothesis = 0 ,
26
27
alpha = 0.1 ,
27
28
success_params = (1 , 1 ),
28
- B = 500 ):
29
+ B = 500 ,
30
+ learner_klass = mixture_learner ):
29
31
'''
30
32
31
33
Compute a p-value (or pivot) for a target having observed `outcome` of `algorithm(observed_sampler)`.
@@ -64,23 +66,16 @@ def infer_general_target(algorithm,
64
66
How many queries?
65
67
'''
66
68
67
- target_sd = np .sqrt (target_cov [0 , 0 ])
68
-
69
- # could use an improvement here...
70
-
71
- def learning_proposal (sd = target_sd , center = observed_target ):
72
- scale = np .random .choice ([0.25 , 0.5 , 1 , 1.5 , 2 , 3 ], 1 )
73
- return np .random .standard_normal () * sd * scale + center
74
-
75
- weight_fn = learn_weights (algorithm ,
76
- observed_outcome ,
77
- observed_sampler ,
78
- observed_target ,
79
- target_cov ,
80
- cross_cov ,
81
- learning_proposal ,
82
- fit_probability ,
69
+ learner = learner_klass (algorithm ,
70
+ observed_set ,
71
+ observed_sampler ,
72
+ observed_target ,
73
+ target_cov ,
74
+ cross_cov )
75
+
76
+ weight_fn = learner .learn (fit_probability ,
83
77
fit_args = fit_args ,
78
+ check_selection = None ,
84
79
B = B )
85
80
86
81
return _inference (observed_target ,
@@ -100,7 +95,8 @@ def infer_full_target(algorithm,
100
95
hypothesis = 0 ,
101
96
alpha = 0.1 ,
102
97
success_params = (1 , 1 ),
103
- B = 500 ):
98
+ B = 500 ,
99
+ learner_klass = mixture_learner ):
104
100
105
101
'''
106
102
@@ -154,22 +150,14 @@ def infer_full_target(algorithm,
154
150
if feature not in observed_set :
155
151
raise ValueError ('for full target, we can only do inference for features observed in the outcome' )
156
152
157
- target_sd = np .sqrt (target_cov [0 , 0 ])
158
-
159
- # could use an improvement here...
160
-
161
- def learning_proposal (sd = target_sd , center = observed_target ):
162
- scale = np .random .choice ([0.5 , 1 , 1.5 , 2 ], 1 )
163
- return np .random .standard_normal () * sd * scale + center
164
-
165
- weight_fn = learn_weights (algorithm ,
166
- observed_set ,
167
- observed_sampler ,
168
- observed_target ,
169
- target_cov ,
170
- cross_cov ,
171
- learning_proposal ,
172
- fit_probability ,
153
+ learner = learner_klass (algorithm ,
154
+ observed_set ,
155
+ observed_sampler ,
156
+ observed_target ,
157
+ target_cov ,
158
+ cross_cov )
159
+
160
+ weight_fn = learner .learn (fit_probability ,
173
161
fit_args = fit_args ,
174
162
check_selection = lambda result : feature in set (result ),
175
163
B = B )
0 commit comments