-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfo.py
91 lines (73 loc) · 3.37 KB
/
info.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# load all_info pickle file and generate graphs as necessary
import pickle
import numpy as np
import matplotlib.pyplot as plt
# Script body: load the per-model accuracy curves from the experiment pickle,
# pick the best model at every epsilon index for each model family, save the
# winning model names, and plot the best curve of each family.

# Input pickle; the code below reads its 'normal', 'blackout', 'lambda_vary',
# 'lambda1_zero' and 'lambda_equal' entries.
info_file = 'adversarial/all_info_exp6.pickle'
num_epsilons = 10
# x-axis values for the plot: num_epsilons evenly spaced points in [0, 0.2]
epsilons = np.linspace(0, 0.2, num=num_epsilons)
# results file - stores the best model name for each value of epsilon
best_models_path = 'adversarial/best_models_exp6.pickle'
best_models = {}


def _update_best(robust_acc, model_name, best_acc, best_names):
    """Fold one model's accuracy curve into running per-index bests.

    For every index where ``robust_acc`` ties or beats the value currently
    stored in ``best_acc``, overwrite ``best_acc[idx]`` with that value and
    ``best_names[idx]`` with ``model_name``. Both lists are mutated in place.
    Ties use ``>=``, so among equal scores the model processed last wins.

    Args:
        robust_acc: iterable of per-index accuracy values for one model.
        model_name: name recorded wherever this model becomes the best.
        best_acc: mutable list of the best values seen so far.
        best_names: mutable list of the model names matching ``best_acc``.
    """
    for idx, item in enumerate(robust_acc):
        if item >= best_acc[idx]:
            best_acc[idx] = item
            best_names[idx] = model_name


# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# point info_file at a file produced by a trusted run.
with open(info_file, 'rb') as read_file:
    info = pickle.load(read_file)

# 'normal' and 'blackout' each contribute a single fixed model, so that model
# is the best one at every epsilon index.
# Stored values expose .cpu().numpy() - presumably torch tensors; TODO confirm.
normal = info['normal']
best_normal = normal['normal_0.0_0.0'].cpu().numpy()
best_models_normal = ['normal_0.0_0.0'] * num_epsilons
best_models['normal'] = best_models_normal

blackout = info['blackout']
best_blackout = blackout['blackout_0.0_0.0'].cpu().numpy()
best_models_blackout = ['blackout_0.0_0.0'] * num_epsilons
best_models['blackout'] = best_models_blackout

# Running bests start at 0.0 / 'None' and are raised by every model seen.
lambda_vary = info['lambda_vary']
best_lambda_vary = [float(0)] * num_epsilons
best_models_lambda_vary = ['None'] * num_epsilons
for model_name, robust_acc in lambda_vary.items():
    robust_acc = robust_acc.cpu().numpy()
    _update_best(robust_acc, model_name,
                 best_lambda_vary, best_models_lambda_vary)

lambda1_zero = info['lambda1_zero']
best_lambda1_zero = [float(0)] * num_epsilons
best_models_lambda1_zero = ['None'] * num_epsilons
for model_name, robust_acc in lambda1_zero.items():
    robust_acc = robust_acc.cpu().numpy()
    _update_best(robust_acc, model_name,
                 best_lambda1_zero, best_models_lambda1_zero)
    # because lambda_vary includes lambda1_zero
    _update_best(robust_acc, model_name,
                 best_lambda_vary, best_models_lambda_vary)
best_lambda1_zero = np.array(best_lambda1_zero)
best_models['lambda1_zero'] = best_models_lambda1_zero

best_lambda_equal = [float(0)] * num_epsilons
best_models_lambda_equal = ['None'] * num_epsilons
lambda_equal = info['lambda_equal']
for model_name, robust_acc in lambda_equal.items():
    robust_acc = robust_acc.cpu().numpy()
    _update_best(robust_acc, model_name,
                 best_lambda_equal, best_models_lambda_equal)
    # because lambda_vary includes lambda_equal
    _update_best(robust_acc, model_name,
                 best_lambda_vary, best_models_lambda_vary)
best_lambda_equal = np.array(best_lambda_equal)
best_models['lambda_equal'] = best_models_lambda_equal

# convert best_lambda_vary only now, after the two loops above have also
# contributed to it
best_lambda_vary = np.array(best_lambda_vary)
best_models['lambda_vary'] = best_models_lambda_vary

# Save the best_models as a pickle file BEFORE plotting: plt.show() blocks
# until the window is closed, and an interrupted or failing plot must not
# discard the computed results.
with open(best_models_path, 'wb') as write_file:
    pickle.dump(best_models, write_file)

plt.plot(epsilons, best_normal, label='normal')
plt.plot(epsilons, best_blackout, label='blackout')
plt.plot(epsilons, best_lambda1_zero, label='lambda1_zero')
plt.plot(epsilons, best_lambda_equal, label='lambda_equal')
plt.plot(epsilons, best_lambda_vary, label='lambda_vary')
plt.legend()
plt.show()

print('Done')