Skip to content

Commit b354d3a

Browse files
new prototype for learners
1 parent d7e4221 commit b354d3a

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

learners.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import functools
2+
3+
import numpy as np
4+
from scipy.stats import norm as ndist
5+
6+
from selection.distributions.discrete_family import discrete_family
7+
8+
from samplers import normal_sampler
9+
10+
class mixture_learner(object):

    """
    Learn a function

        P(Y=1|T, N=S-c*T)

    where N is the sufficient statistic corresponding to nuisance parameters
    and T is our target.

    The random variable Y is

        Y = check_selection(algorithm(new_sampler))

    That is, we perturb the center of `observed_sampler` along a ray (or
    higher-dimensional affine subspace) and rerun the algorithm, checking to
    see if the test `check_selection` passes.

    For full model inference, `check_selection` will typically check to see if
    a given feature is still in the selected set. For general targets, we will
    typically condition on the exact observed value of
    `algorithm(observed_sampler)`.
    """

    def __init__(self,
                 algorithm,
                 observed_outcome,
                 observed_sampler,
                 observed_target,
                 target_cov,
                 cross_cov):

        """
        Parameters
        ----------

        algorithm : callable
            Selection algorithm that takes a noise source as its only argument.

        observed_outcome : object
            The purported value `algorithm(observed_sampler)`, i.e. run with
            the original seed.

        observed_sampler : `normal_sampler`
            Representation of the data used in the selection procedure.

        observed_target : ndarray
            Observed value of the target estimate.

        target_cov : ndarray
            Covariance of the target estimate (treated as scalar / 1x1 below).

        cross_cov : ndarray
            Covariance between the full sufficient statistic and the target.
        """

        self.algorithm = algorithm
        self.observed_outcome = observed_outcome
        self.observed_sampler = observed_sampler
        self.observed_target = observed_target
        self.target_cov = target_cov
        self.cross_cov = cross_cov

    def learning_proposal(self):
        """
        Propose a new value of the target T by perturbing the observed target
        with Gaussian noise at a randomly chosen scale.

        Returns
        -------
        T : ndarray
            Proposed target value, centered at `self.observed_target`.
        """
        sd = np.sqrt(self.target_cov)
        center = self.observed_target
        # mixture over several proposal widths -- hence "mixture_learner"
        scale = np.random.choice([0.5, 1, 1.5, 2], 1)
        return np.random.standard_normal() * sd * scale + center

    def learn(self,
              fit_probability,
              fit_args=None,
              B=500,
              check_selection=None):

        """
        Run the selection algorithm B times at proposed target values and fit
        a probability model to the resulting (target, selection) pairs.

        Parameters
        ----------

        fit_probability : callable
            Function to learn a probability model P(Y=1|T) based on [T, Y].

        fit_args : dict, optional
            Keyword arguments to `fit_probability`.

        B : int
            How many queries?

        check_selection : callable, optional
            Callable that determines the selection variable. Defaults to
            exact equality with `self.observed_outcome`.

        Returns
        -------
        conditional_law : object
            The fitted probability model returned by `fit_probability`.
        """

        # avoid a shared mutable default argument
        if fit_args is None:
            fit_args = {}

        (algorithm,
         observed_outcome,
         observed_sampler,
         observed_target,
         target_cov,
         cross_cov) = (self.algorithm,
                       self.observed_outcome,
                       self.observed_sampler,
                       self.observed_target,
                       self.target_cov,
                       self.cross_cov)

        S = selection_stat = observed_sampler.center

        # copy so that re-centering below does not mutate observed_sampler
        new_sampler = normal_sampler(observed_sampler.center.copy(),
                                     observed_sampler.covariance.copy())

        if check_selection is None:
            check_selection = lambda result: result == observed_outcome

        # move along a ray through S with this direction
        direction = cross_cov.dot(np.linalg.inv(target_cov).reshape((1, 1)))

        learning_Y, learning_T = [], []

        def random_meta_algorithm(new_sampler, algorithm, check_selection, T):
            # re-center the noise source as a function of the proposed target
            new_sampler.center = S + direction.dot(T - observed_target)
            new_result = algorithm(new_sampler)
            return check_selection(new_result)

        random_algorithm = functools.partial(random_meta_algorithm,
                                             new_sampler,
                                             algorithm,
                                             check_selection)

        # this is the "active learning" bit
        for _ in range(B):
            # a guess at an informative distribution for learning what we want
            T = self.learning_proposal()
            Y = random_algorithm(T)

            learning_Y.append(Y)
            learning_T.append(T)

        # use builtin float: np.float was deprecated in NumPy 1.20 and
        # removed in NumPy 1.24
        learning_Y = np.array(learning_Y, float)
        learning_T = np.squeeze(np.array(learning_T, float))

        print('prob(select): ', np.mean(learning_Y))
        conditional_law = fit_probability(learning_T, learning_Y, **fit_args)
        return conditional_law
142+

0 commit comments

Comments
 (0)