Skip to content

Commit 4e771e6

Browse files
null example of unrandomized CV
1 parent 701759f commit 4e771e6

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed

lasso_example_null_CV.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import functools
2+
3+
import numpy as np
4+
from scipy.stats import norm as ndist
5+
6+
import regreg.api as rr
7+
8+
from selection.tests.instance import gaussian_instance
9+
from knockoffs import lasso_glmnet
10+
11+
from core import (infer_full_target,
12+
split_sampler, # split_sampler not working yet
13+
normal_sampler,
14+
logit_fit,
15+
probit_fit)
16+
17+
def simulate(n=100, p=50, s=10, signal=(0, 0), sigma=2, alpha=0.1):
18+
19+
# description of statistical problem
20+
21+
X, y, truth = gaussian_instance(n=n,
22+
p=p,
23+
s=s,
24+
equicorrelated=False,
25+
rho=0.0,
26+
sigma=sigma,
27+
signal=signal,
28+
random_signs=True,
29+
scale=False)[:3]
30+
31+
XTX = X.T.dot(X)
32+
XTXi = np.linalg.inv(XTX)
33+
resid = y - X.dot(XTXi.dot(X.T.dot(y)))
34+
dispersion = np.linalg.norm(resid)**2 / (n-p)
35+
36+
S = X.T.dot(y)
37+
covS = dispersion * X.T.dot(X)
38+
smooth_sampler = normal_sampler(S, covS)
39+
splitting_sampler = split_sampler(X * y[:, None], covS)
40+
41+
def meta_algorithm(X, XTXi, resid, sampler):
42+
43+
S = sampler(scale=0.) # deterministic with scale=0
44+
ynew = X.dot(XTXi).dot(S) + resid # will be ok for n>p and non-degen X
45+
G = lasso_glmnet(X, ynew, *[None]*4)
46+
select = G.select()
47+
return set(list(select[0]))
48+
49+
selection_algorithm = functools.partial(meta_algorithm, X, XTXi, resid)
50+
51+
# run selection algorithm
52+
53+
observed_set = selection_algorithm(splitting_sampler)
54+
55+
# find the target, based on the observed outcome
56+
57+
# we just take the first target
58+
59+
pivots, covered, lengths = [], [], []
60+
naive_pivots, naive_covered, naive_lengths = [], [], []
61+
62+
for idx in list(observed_set)[:1]:
63+
print("variable: ", idx, "total selected: ", len(observed_set))
64+
true_target = truth[idx]
65+
66+
(pivot,
67+
interval) = infer_full_target(selection_algorithm,
68+
observed_set,
69+
idx,
70+
splitting_sampler,
71+
dispersion,
72+
hypothesis=true_target,
73+
fit_probability=probit_fit,
74+
alpha=alpha,
75+
B=500)
76+
77+
pivots.append(pivot)
78+
covered.append((interval[0] < true_target) * (interval[1] > true_target))
79+
lengths.append(interval[1] - interval[0])
80+
81+
target_sd = np.sqrt(dispersion * XTXi[idx, idx])
82+
observed_target = np.squeeze(XTXi[idx].dot(X.T.dot(y)))
83+
quantile = ndist.ppf(1 - 0.5 * alpha)
84+
naive_interval = (observed_target-quantile * target_sd, observed_target+quantile * target_sd)
85+
naive_pivots.append((1-ndist.cdf((observed_target-true_target)/target_sd))) # one-sided
86+
87+
naive_covered.append((naive_interval[0]<true_target)*(naive_interval[1]>true_target))
88+
naive_lengths.append(naive_interval[1]-naive_interval[0])
89+
90+
return pivots, covered, lengths, naive_pivots, naive_covered, naive_lengths
91+
92+
93+
if __name__ == "__main__":
94+
import statsmodels.api as sm
95+
import matplotlib.pyplot as plt
96+
97+
np.random.seed(1)
98+
99+
U = np.linspace(0, 1, 101)
100+
P, L, coverage = [], [], []
101+
naive_P, naive_L, naive_coverage = [], [], []
102+
plt.clf()
103+
for i in range(500):
104+
p, cover, l, naive_p, naive_covered, naive_l = simulate()
105+
coverage.extend(cover)
106+
P.extend(p)
107+
L.extend(l)
108+
naive_P.extend(naive_p)
109+
naive_coverage.extend(naive_covered)
110+
naive_L.extend(naive_l)
111+
112+
print("selective:", np.mean(P), np.std(P), np.mean(L) , np.mean(coverage))
113+
print("naive:", np.mean(naive_P), np.std(naive_P), np.mean(naive_L), np.mean(naive_coverage))
114+
print("len ratio selective divided by naive:", np.mean(np.array(L) / np.array(naive_L)))
115+
116+
if i % 2 == 0 and i > 0:
117+
plt.clf()
118+
plt.plot(U, sm.distributions.ECDF(P)(U), 'r', label='Selective', linewidth=3)
119+
plt.plot([0,1], [0,1], 'k--', linewidth=2)
120+
plt.plot(U, sm.distributions.ECDF(naive_P)(U), 'b', label='Naive', linewidth=3)
121+
plt.legend()
122+
plt.savefig('lasso_example_null_CV.pdf')

0 commit comments

Comments
 (0)