-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
192 lines (154 loc) · 7.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import argparse
import copy
import logging
import os
from functools import partial
import jax
import numpy as np
from matplotlib import rcParams
from learnware.market import easy
from tqdm import tqdm
from benchmark import best_match_performance, average_performance_totally
from build_market import build_from_preprocessed, upload_to_easy_market
from evaluate_market import evaluate_market_performance
from diagram.plot_accuracy import plot_accuracy_diagram, load_users
from diagram.plot_spec import load_market, plot_comparison_diagram
from preprocess.split_data import generate
from preprocess.train_model import train_model
from utils import ntk_rkme
from utils.clerk import get_custom_logger, Clerk
# ---------------------------------------------------------------------------
# Command-line interface. `args` is a module-level global consumed by every
# mode function below; valid --mode values are the keys of the dispatch dict
# in the __main__ block.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='NTK-RF Experiments Remake')
# AUTO_PARAM = "data_id"
parser.add_argument('--mode', type=str, default="regular")
parser.add_argument('--token', default=None, help='Used for auto.bash')
parser.add_argument('--auto_param', type=str, default=None, help='search param in auto model, None for regular mode')
# train
parser.add_argument('--cuda_idx', type=int, default=0,
                    help='ID of device')
parser.add_argument('--no_reduce', default=False, action=argparse.BooleanOptionalAction, help='whether to reduce')
# learnware
# Worker index for parallel runs; in auto/resplit modes it also selects the
# CANDIDATES entry this worker evaluates.
parser.add_argument('--id', type=int, default=0,
                    help='Used for parallel training')
# NOTE(review): help text lists [rbf, NTK] but the default is lowercase
# 'ntk' — presumably compared case-sensitively downstream; confirm.
parser.add_argument('--spec', type=str, default='ntk',
                    help='Specification, options: [rbf, NTK]')
parser.add_argument('--market_root', type=str, default='market',
                    help='Path of Market')
parser.add_argument('--max_search_num', type=int, default=3,
                    help='Number of Max Search Learnware to ensemble')
parser.add_argument('-K', type=int, default=50,
                    help='number of reduced points')
# data
parser.add_argument('--resplit', default=False, action=argparse.BooleanOptionalAction,
                    help='Resplit datasets')
parser.add_argument('--regenerate', default=True, action=argparse.BooleanOptionalAction,
                    help='whether to regenerate specs and learnwares')
parser.add_argument('--data', type=str, default='cifar10', help='dataset type')
parser.add_argument('--image_size', type=int, default=32)
parser.add_argument('--data_root', type=str, default=r"image_models",
                    help='The path of images and models')
parser.add_argument('--n_uploaders', type=int, default=50, help='Number of uploaders')
parser.add_argument('--n_users', type=int, default=50, help='Number of users')
parser.add_argument('--data_id', type=int, default=0, help='market data id')
# ntk — parameters of the random models used for NTK-based specifications
parser.add_argument('--model', type=str,
                    default="conv", help='The model used to generate random features')
parser.add_argument('--n_models', type=int, default=16,
                    help='# of random models')
parser.add_argument('--n_random_features', type=int, default=4096,
                    help='out features of random model')
parser.add_argument('--net_width', type=int, default=128,
                    help='# of inner channels of random model')
parser.add_argument('--net_depth', type=int, default=3,
                    help='network depth of conv')
parser.add_argument('--activation', type=str,
                    default='relu', help='activation of random model')
parser.add_argument('--ntk_steps', type=int,
                    default=100, help='steps of optimization')
parser.add_argument('--ntk_factor', type=float, # TODO: Why this is not working ?
                    default=1, help='factor of steps of optimization')
parser.add_argument('--sigma', type=float, default=None, help='standard variance of random models')
args = parser.parse_args()
# Hyper-parameter sweep grid for auto mode: worker `--id` evaluates
# CANDIDATES[<auto_param>][id] (see `_auto_mode`). The "data_id" list is also
# used by `_re_split_mode` to pick this worker's data split.
CANDIDATES = {
    "model": ['conv', 'resnet'],
    "ntk_steps": [70, 80, 90, 100, 110, 120, 130],
    # "ntk_factor": [0.8, 0.9, 1.0, 1.05, 1.1, 1.2, 1.3, 1.4],
    "sigma": [0.003, 0.004, 0.005, 0.006, 0.01, 0.025, 0.05, 0.1],
    "n_random_features": [32, 64, 96, 128, 196, 256],
    "net_width": [32, 64, 96, 128, 160, 196],
    "data_id": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    # NOTE(review): values repeat deliberately, presumably to give two
    # parallel workers per depth — confirm.
    "net_depth": [3, 3, 4, 4, 5, 5, 6, 6]
}
def _regular_mode(clerk=None):
    """Run the standard pipeline: build learnwares, upload them to an
    easy market, evaluate the market, then report the oracle baseline.

    :param clerk: optional Clerk collecting performance statistics.
    """
    if args.resplit:
        _re_split_mode()
    learnwares = build_from_preprocessed(args, regenerate=args.regenerate)
    easy_market = upload_to_easy_market(args, learnwares)
    evaluate_market_performance(args, easy_market, clerk=clerk, regenerate=args.regenerate)
    # Release JAX backend resources before the follow-up benchmark.
    jax.clear_backends()
    best_match_performance(args, clerk=clerk)
    # Echo the full configuration so each log is self-describing.
    log = get_custom_logger()
    rule = "=" * 45
    log.info(rule)
    for key, value in vars(args).items():
        log.info("{:<10}:{}".format(key, value))
    log.info(rule)
def _re_split_mode():
    """Re-split the data: pick this worker's split id, regenerate the
    datasets, retrain uploader models, and report the oracle baseline."""
    args.data_id = CANDIDATES["data_id"][args.id]
    generate(args)
    train_model(args)
    best_match_performance(args)
def _auto_mode(search_key, clerk=None):
    """Run one point of a hyper-parameter sweep.

    Worker `args.id` takes the id-th candidate value for `search_key`
    (from CANDIDATES), binds itself to a GPU round-robin, then runs the
    regular pipeline. With `search_key=None` it degenerates to a plain
    regular run.

    :param search_key: name of the swept parameter, or None.
    :param clerk: optional Clerk collecting performance statistics.
    """
    logger = get_custom_logger()
    cuda_pool = [0, 1, 2, 3, 4, 5, 6, 7]
    if search_key is not None:
        # Round-robin GPU assignment across parallel workers.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(cuda_pool[args.id % len(cuda_pool)])
        candidates = CANDIDATES[search_key]
        if args.id >= len(candidates):
            # More workers launched than candidate values: nothing to do.
            return
        setattr(args, search_key, candidates[args.id])
    rule = "=" * 45
    logger.info(rule)
    for name, value in vars(args).items():
        logger.info("{:<10}:{}".format(name, value))
    logger.info(rule)
    _regular_mode(clerk=clerk)
    print(ntk_rkme.RKMEStatSpecification.INNER_PRODUCT_COUNT)
def _plot_spec_mode():
    """Load both markets and plot the RBF-vs-NTK specification diagram."""
    markets = load_market(args)
    rbf, ntk = markets
    plot_comparison_diagram(args, 10, rbf, ntk)
def _plot_accuracy_mode():
    """Plot the accuracy comparison diagram for both specification types."""
    # SimHei serif font — presumably needed for CJK glyphs in labels; confirm.
    rcParams['font.family'] = 'serif'
    rcParams['font.serif'] = 'SimHei'
    rbf, ntk = load_market(args)
    rbf_user_specs, ntk_user_specs = load_users(args)
    plot_accuracy_diagram(args, rbf, ntk, rbf_user_specs, ntk_user_specs)
def _average_performance_mode():
    """Report performance averaged over market/user ids 0-7."""
    index_range = list(range(8))
    # Pass two distinct lists, matching the original call exactly.
    average_performance_totally(args, index_range, index_range.copy())
def _oracle_performance_mode(num=8):
    """Measure oracle (best-match) accuracy over the first `num` data splits.

    For each split, deep-copies the global `args`, points it at that split,
    records the mean best-match accuracy via a fresh Clerk, then prints the
    mean/std across splits followed by the per-split values.

    :param num: number of market data splits to evaluate.
    """
    accuracies = []
    # The original iterated zip(range(num), range(num)) — both indices were
    # always equal, so a single range is equivalent (and lets tqdm infer the
    # total on its own).
    for split_id in tqdm(range(num)):
        args_ = copy.deepcopy(args)  # don't mutate the shared global args
        args_.data_id = split_id
        args_.id = split_id
        clerk = Clerk()
        best_match_performance(args_, clerk)
        accuracies.append(np.mean(clerk.best))
    accuracies = np.asarray(accuracies)
    print("Oracle Case({:d}): {:.5f} {:.5f}".format(args.max_search_num,
                                                    np.mean(accuracies), np.std(accuracies)))
    print(" ".join(["{:.5f}".format(v) for v in accuracies]))
if __name__ == "__main__":
    # Pin the process to the requested GPU, then address it as device 0.
    # NOTE(review): jax is imported at module top, so setting
    # CUDA_VISIBLE_DEVICES here may be too late for its device discovery —
    # confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_idx)
    args.cuda_idx = 0
    # Quiet the learnware market's own logging.
    easy.logger.setLevel(logging.WARNING)
    performance_clerk = Clerk()
    # Dispatch table mapping --mode to its entry point.
    dispatch = {
        "resplit": _re_split_mode,
        "regular": partial(_regular_mode, clerk=performance_clerk),
        "auto": partial(_auto_mode, args.auto_param, clerk=performance_clerk),
        "plot_spec": _plot_spec_mode,
        "plot_accuracy": _plot_accuracy_mode,
        "average_performance": _average_performance_mode,
        "oracle_performance": _oracle_performance_mode,
    }
    if args.mode not in dispatch:
        raise NotImplementedError()
    dispatch[args.mode]()
    print(performance_clerk)