Skip to content

Commit 701759f

Browse files
made new learner klass, BF to size of target_sd
1 parent b354d3a commit 701759f

File tree

2 files changed

+23
-35
lines changed

2 files changed

+23
-35
lines changed

core.py

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
probit_fit)
1414
from samplers import (normal_sampler,
1515
split_sampler)
16+
from learners import mixture_learner
1617

1718
def infer_general_target(algorithm,
1819
observed_outcome,
@@ -25,7 +26,8 @@ def infer_general_target(algorithm,
2526
hypothesis=0,
2627
alpha=0.1,
2728
success_params=(1, 1),
28-
B=500):
29+
B=500,
30+
learner_klass=mixture_learner):
2931
'''
3032
3133
Compute a p-value (or pivot) for a target having observed `outcome` of `algorithm(observed_sampler)`.
@@ -64,23 +66,16 @@ def infer_general_target(algorithm,
6466
How many queries?
6567
'''
6668

67-
target_sd = np.sqrt(target_cov[0, 0])
68-
69-
# could use an improvement here...
70-
71-
def learning_proposal(sd=target_sd, center=observed_target):
72-
scale = np.random.choice([0.25, 0.5, 1, 1.5, 2, 3], 1)
73-
return np.random.standard_normal() * sd * scale + center
74-
75-
weight_fn = learn_weights(algorithm,
76-
observed_outcome,
77-
observed_sampler,
78-
observed_target,
79-
target_cov,
80-
cross_cov,
81-
learning_proposal,
82-
fit_probability,
69+
learner = learner_klass(algorithm,
70+
observed_set,
71+
observed_sampler,
72+
observed_target,
73+
target_cov,
74+
cross_cov)
75+
76+
weight_fn = learner.learn(fit_probability,
8377
fit_args=fit_args,
78+
check_selection=None,
8479
B=B)
8580

8681
return _inference(observed_target,
@@ -100,7 +95,8 @@ def infer_full_target(algorithm,
10095
hypothesis=0,
10196
alpha=0.1,
10297
success_params=(1, 1),
103-
B=500):
98+
B=500,
99+
learner_klass=mixture_learner):
104100

105101
'''
106102
@@ -154,22 +150,14 @@ def infer_full_target(algorithm,
154150
if feature not in observed_set:
155151
raise ValueError('for full target, we can only do inference for features observed in the outcome')
156152

157-
target_sd = np.sqrt(target_cov[0, 0])
158-
159-
# could use an improvement here...
160-
161-
def learning_proposal(sd=target_sd, center=observed_target):
162-
scale = np.random.choice([0.5, 1, 1.5, 2], 1)
163-
return np.random.standard_normal() * sd * scale + center
164-
165-
weight_fn = learn_weights(algorithm,
166-
observed_set,
167-
observed_sampler,
168-
observed_target,
169-
target_cov,
170-
cross_cov,
171-
learning_proposal,
172-
fit_probability,
153+
learner = learner_klass(algorithm,
154+
observed_set,
155+
observed_sampler,
156+
observed_target,
157+
target_cov,
158+
cross_cov)
159+
160+
weight_fn = learner.learn(fit_probability,
173161
fit_args=fit_args,
174162
check_selection=lambda result: feature in set(result),
175163
B=B)

learners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self,
6666
cross_cov)
6767

6868
def learning_proposal(self):
69-
sd = np.sqrt(self.target_cov)
69+
sd = np.sqrt(self.target_cov[0, 0])
7070
center = self.observed_target
7171
scale = np.random.choice([0.5, 1, 1.5, 2], 1)
7272
return np.random.standard_normal() * sd * scale + center

0 commit comments

Comments
 (0)