Commit 21751b8

change that will be undone
1 parent 739624d commit 21751b8

1 file changed: +345 additions, 0 deletions

Diff for: core.py

@@ -0,0 +1,345 @@
from copy import copy
import functools

import numpy as np
from scipy.stats import norm as ndist
from scipy.stats import binom

from selection.distributions.discrete_family import discrete_family

# local imports

from fitters import (logit_fit,
                     probit_fit)
from samplers import (normal_sampler,
                      split_sampler)
from learners import mixture_learner

def infer_general_target(algorithm,
                         observed_outcome,
                         observed_sampler,
                         observed_target,
                         cross_cov,
                         target_cov,
                         fit_probability=probit_fit,
                         fit_args={},
                         hypothesis=0,
                         alpha=0.1,
                         success_params=(1, 1),
                         B=500,
                         learner_klass=mixture_learner):
    '''

    Compute a p-value (or pivot) for a target having observed `outcome` of `algorithm(observed_sampler)`.

    Parameters
    ----------

    algorithm : callable
        Selection algorithm that takes a noise source as its only argument.

    observed_outcome : object
        The purported value `algorithm(observed_sampler)`, i.e. run with the original seed.

    observed_sampler : `normal_source`
        Representation of the data used in the selection procedure.

    cross_cov : np.float((*, 1))  # 1 for 1-dimensional targets for now
        Covariance between `observed_sampler.center` and target estimator.

    target_cov : np.float((1, 1))  # 1 for 1-dimensional targets for now
        Covariance of target estimator.

    observed_target : np.float  # 1-dimensional targets for now
        Observed value of target estimator.

    fit_probability : callable
        Function to learn a probability model P(Y=1|T) based on [T, Y].

    hypothesis : np.float  # 1-dimensional targets for now
        Hypothesized value of target.

    alpha : np.float
        Level for 1 - confidence.

    B : int
        How many queries?
    '''

    learner = learner_klass(algorithm,
                            observed_outcome,
                            observed_sampler,
                            observed_target,
                            target_cov,
                            cross_cov)

    weight_fn = learner.learn(fit_probability,
                              fit_args=fit_args,
                              check_selection=None,
                              B=B)

    return _inference(observed_target,
                      target_cov,
                      weight_fn,
                      hypothesis=hypothesis,
                      alpha=alpha,
                      success_params=success_params)

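# A minimal sketch of how the `observed_target`, `target_cov` and `cross_cov`
# arguments of `infer_general_target` can be built for a one-dimensional linear
# target eta'theta estimated by eta'S, where S is `observed_sampler.center` with
# covariance `observed_sampler.covariance`. The contrast vector `eta` is a
# placeholder for illustration; it is not fixed by this module.
def _example_general_target_setup(observed_sampler, eta):
    observed_target = eta.dot(observed_sampler.center)                            # observed value of eta'S
    target_cov = np.reshape(eta.dot(observed_sampler.covariance.dot(eta)), (1, 1))  # Var(eta'S)
    cross_cov = observed_sampler.covariance.dot(eta).reshape((-1, 1))             # Cov(S, eta'S)
    return observed_target, target_cov, cross_cov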

def infer_full_target(algorithm,
                      observed_set,
                      feature,
                      observed_sampler,
                      dispersion,  # sigma^2
                      fit_probability=probit_fit,
                      fit_args={},
                      hypothesis=0,
                      alpha=0.1,
                      success_params=(1, 1),
                      B=500,
                      learner_klass=mixture_learner):

    '''

    Compute a p-value (or pivot) for a target having observed `outcome` of `algorithm(observed_sampler)`.

    Parameters
    ----------

    algorithm : callable
        Selection algorithm that takes a noise source as its only argument.

    observed_set : set(int)
        The purported value `algorithm(observed_sampler)`, i.e. run with the original seed.

    feature : int
        One of the elements of observed_set.

    observed_sampler : `normal_source`
        Representation of the data used in the selection procedure.

    fit_probability : callable
        Function to learn a probability model P(Y=1|T) based on [T, Y].

    hypothesis : np.float  # 1-dimensional targets for now
        Hypothesized value of target.

    alpha : np.float
        Level for 1 - confidence.

    B : int
        How many queries?

    Notes
    -----

    This function assumes that the covariance in `observed_sampler` is the true
    covariance of S and that we are doing inference about coordinates of the mean of

        np.linalg.inv(covariance).dot(S)

    This allows us to compute the required observed_target, cross_cov and target_cov.

    '''

    info_inv = np.linalg.inv(observed_sampler.covariance / dispersion)  # scale free, i.e. X.T.dot(X) without sigma^2
    target_cov = (info_inv[feature, feature] * dispersion).reshape((1, 1))
    observed_target = np.squeeze(info_inv[feature].dot(observed_sampler.center))
    cross_cov = observed_sampler.covariance.dot(info_inv[feature]).reshape((-1, 1))

    observed_set = set(observed_set)
    if feature not in observed_set:
        raise ValueError('for full target, we can only do inference for features observed in the outcome')

    learner = learner_klass(algorithm,
                            observed_set,
                            observed_sampler,
                            observed_target,
                            target_cov,
                            cross_cov)

    weight_fn = learner.learn(fit_probability,
                              fit_args=fit_args,
                              check_selection=lambda result: feature in set(result),
                              B=B)

    return _inference(observed_target,
                      target_cov,
                      weight_fn,
                      hypothesis=hypothesis,
                      alpha=alpha,
                      success_params=success_params)

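# A minimal usage sketch for `infer_full_target`, assuming a design matrix X and
# response y from a Gaussian linear model with noise variance `dispersion`. The
# thresholding selection rule below is a placeholder used only for illustration;
# it relies only on the `center` and `covariance` attributes that a sampler exposes.
def _example_infer_full_target(X, y, dispersion):

    S = X.T.dot(y)                    # sufficient statistic
    covS = X.T.dot(X) * dispersion    # its covariance
    sampler = normal_sampler(S, covS)

    def selection_algorithm(sampler):
        # keep features whose marginal statistic is more than 2 SD from zero
        scale = np.sqrt(np.diag(sampler.covariance))
        return set(np.nonzero(np.fabs(sampler.center) > 2 * scale)[0])

    observed_set = selection_algorithm(sampler)
    feature = sorted(observed_set)[0]  # inference for one selected feature (assumes the set is nonempty)

    pivot, interval = infer_full_target(selection_algorithm,
                                        observed_set,
                                        feature,
                                        sampler,
                                        dispersion,
                                        alpha=0.1,
                                        B=500)
    return pivot, interval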

def learn_weights(algorithm,
                  observed_outcome,
                  observed_sampler,
                  observed_target,
                  target_cov,
                  cross_cov,
                  learning_proposal,
                  fit_probability,
                  fit_args={},
                  B=500,
                  check_selection=None):
    """
    Learn a function

    P(Y=1|T, N=S-c*T)

    where N is the sufficient statistic corresponding to nuisance parameters and T is our target.
    The random variable Y is

    Y = check_selection(algorithm(new_sampler))

    That is, we perturb the center of observed_sampler along a ray (or higher-dimensional affine
    subspace) and rerun the algorithm, checking to see if the test `check_selection` passes.

    For full model inference, `check_selection` will typically check to see if a given feature
    is still in the selected set. For general targets, we will typically condition on the exact observed value
    of `algorithm(observed_sampler)`.

    Parameters
    ----------

    algorithm : callable
        Selection algorithm that takes a noise source as its only argument.

    observed_outcome : object
        The purported value `algorithm(observed_sampler)`, i.e. run with the original seed.

    observed_sampler : `normal_source`
        Representation of the data used in the selection procedure.

    learning_proposal : callable
        Proposes a new position of T at which to evaluate the algorithm.

    fit_probability : callable
        Function to learn a probability model P(Y=1|T) based on [T, Y].

    B : int
        How many queries?

    """
    S = selection_stat = observed_sampler.center

    new_sampler = normal_sampler(observed_sampler.center.copy(),
                                 observed_sampler.covariance.copy())

    if check_selection is None:
        check_selection = lambda result: result == observed_outcome

    direction = cross_cov.dot(np.linalg.inv(target_cov).reshape((1, 1)))  # move along a ray through S with this direction

    learning_Y, learning_T = [], []

    def random_meta_algorithm(new_sampler, algorithm, check_selection, T):
        new_sampler.center = S + direction.dot(T - observed_target)
        new_result = algorithm(new_sampler)
        return check_selection(new_result)

    random_algorithm = functools.partial(random_meta_algorithm, new_sampler, algorithm, check_selection)

    # this is the "active learning" bit
    # START

    for _ in range(B):
        T = learning_proposal()  # a guess at an informative distribution for learning what we want
        Y = random_algorithm(T)

        learning_Y.append(Y)
        learning_T.append(T)

    learning_Y = np.array(learning_Y, float)
    learning_T = np.squeeze(np.array(learning_T, float))

    print('prob(select): ', np.mean(learning_Y))
    conditional_law = fit_probability(learning_T, learning_Y, **fit_args)
    return conditional_law

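# A sketch of the kind of callable `learn_weights` expects for `learning_proposal`:
# each call returns one candidate value of the target T at which to query the
# algorithm. Sampling around the observed target with an inflated standard
# deviation is one simple choice; the factor of 3 is arbitrary and illustrative.
def _example_learning_proposal(observed_target, target_cov):
    target_sd = np.sqrt(target_cov[0, 0])

    def learning_proposal():
        return observed_target + 3 * target_sd * np.random.standard_normal()

    return learning_proposal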

# Private functions

def _inference(observed_target,
               target_cov,
               weight_fn,  # our fitted function
               success_params=(1, 1),
               hypothesis=0,
               alpha=0.1):

    '''

    Produce p-values (or pivots) and confidence intervals having estimated a weighting function.

    The basic object here is a 1-dimensional exponential family with reference density proportional
    to

        lambda t: scipy.stats.norm.pdf(t / np.sqrt(target_cov)) * weight_fn(t)

    Parameters
    ----------

    observed_target : float
        Observed value of target estimator.

    target_cov : np.float((1, 1))
        Covariance of target estimator.

    hypothesis : float
        Hypothesized true mean of target.

    alpha : np.float
        Level for 1 - confidence.

    Returns
    -------

    pivot : float
        Probability integral transform of the observed_target at mean parameter "hypothesis".

    confidence_interval : (float, float)
        (1 - alpha) * 100% confidence interval.

    '''

    k, m = success_params  # need at least k of m successes

    target_sd = np.sqrt(target_cov[0, 0])

    target_val = np.linspace(-20 * target_sd, 20 * target_sd, 5001) + observed_target

    if (k, m) != (1, 1):
        weight_val = np.array([binom(m, p).sf(k - 1) for p in weight_fn(target_val)])
    else:
        weight_val = weight_fn(target_val)

    weight_val *= ndist.pdf(target_val / target_sd)
    exp_family = discrete_family(target_val, weight_val)

    pivot = exp_family.cdf(hypothesis / target_cov[0, 0], x=observed_target)
    pivot = 2 * min(pivot, 1 - pivot)

    interval = exp_family.equal_tailed_interval(observed_target, alpha=alpha)
    rescaled_interval = (interval[0] * target_cov[0, 0], interval[1] * target_cov[0, 0])

    return pivot, rescaled_interval  # TODO: should do MLE as well; does discrete_family do this?

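# Sanity-check sketch for `_inference`: with a constant weight function the tilted
# family is just a discretized Gaussian, so the pivot should be close to the
# ordinary two-sided Z-test p-value. The numbers below are arbitrary placeholders.
def _example_inference_flat_weights():
    observed_target = 1.5
    target_cov = np.array([[1.]])
    flat_weight_fn = lambda t: np.ones_like(t)
    pivot, interval = _inference(observed_target,
                                 target_cov,
                                 flat_weight_fn,
                                 hypothesis=0,
                                 alpha=0.1)
    # pivot should be approximately 2 * ndist.sf(1.5)
    return pivot, interval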

def repeat_selection(base_algorithm, sampler, min_success, num_tries):
    """
    Repeat a set-returning selection algorithm `num_tries` times,
    returning all elements that appear at least `min_success` times.
    """

    results = {}

    for _ in range(num_tries):
        current = base_algorithm(sampler)
        for item in current:
            results.setdefault(item, 0)
            results[item] += 1

    final_value = []
    for key in results:
        if results[key] >= min_success:
            final_value.append(key)

    return set(final_value)
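
# Usage sketch for `repeat_selection`: a repeated, "stabilized" selection rule is
# itself a set-returning algorithm of one sampler argument, so it can be wrapped
# with functools.partial and passed to `infer_full_target` in place of the
# single-run rule. The base rule and the thresholds (5 of 10) are placeholders.
def _example_stabilized_algorithm(base_algorithm):
    return functools.partial(repeat_selection, base_algorithm,
                             min_success=5, num_tries=10)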
