-from models.selector import *
-from utils.util import *
-from data_loader import *
-from torch.utils.data import DataLoader
-from config import get_arguments
-from tqdm import tqdm
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from tqdm import tqdm


-def compute_loss_value(opt, poisoned_data, model_ascent):
-    # Calculate loss value per example
-    # Define loss function
-    if opt.cuda:
-        criterion = nn.CrossEntropyLoss().cuda()
-    else:
-        criterion = nn.CrossEntropyLoss()
+def min_max_normalization(x):
+    # Scale x into [0, 1]; assumes x is not constant (x_max > x_min),
+    # otherwise the division below is zero-by-zero
+    x_min = torch.min(x)
+    x_max = torch.max(x)
+    norm = (x - x_min) / (x_max - x_min)
+    return norm

-    model_ascent.eval()
-    losses_record = []

-    example_data_loader = DataLoader(dataset=poisoned_data,
-                                     batch_size=1,
-                                     shuffle=False,
-                                     )
+class ABLAnalysis():
+    def __init__(self, opt):
+        # Based on https://github.com/bboylyg/ABL/blob/main/backdoor_isolation.py
+        # Keep the parsed options: the methods below read opt.cuda,
+        # opt.isolation_ratio, opt.save, opt.isolate_data_root and opt.model_name
+        self.opt = opt

-    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()
+    def compute_loss_value(self, data, model_ascent):
+        # Calculate the loss value per example
+        # Define the loss function
+        if self.opt.cuda:
+            criterion = nn.CrossEntropyLoss().cuda()
+        else:
+            criterion = nn.CrossEntropyLoss()

-    with torch.no_grad():
-        output = model_ascent(img)
-        loss = criterion(output, target)
-        # print(loss.item())
+        model_ascent.eval()
+        losses_record = []

-    losses_record.append(loss.item())
+        example_data_loader = DataLoader(dataset=data,
+                                         batch_size=1,
+                                         shuffle=False,
+                                         )

-    losses_idx = np.argsort(np.array(losses_record))  # get the index of examples by loss value in ascending order
+        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+            if self.opt.cuda:
+                img = img.cuda()
+                target = target.cuda()

-    # Show the lowest 10 loss values
-    losses_record_arr = np.array(losses_record)
-    print('Top ten loss value:', losses_record_arr[losses_idx[:10]])
+            with torch.no_grad():
+                output = model_ascent(img)
+                loss = criterion(output, target)

-    return losses_idx
+            losses_record.append(loss.item())

+        losses_idx = np.argsort(np.array(losses_record))  # indices of examples sorted by loss value, ascending

-def isolate_data(opt, poisoned_data, losses_idx):
-    # Initialize lists
-    other_examples = []
-    isolation_examples = []
+        # Show the lowest 10 loss values
+        losses_record_arr = np.array(losses_record)
+        print('Lowest ten loss values:', losses_record_arr[losses_idx[:10]])

-    cnt = 0
-    ratio = opt.isolation_ratio
+        return losses_idx

-    example_data_loader = DataLoader(dataset=poisoned_data,
-                                     batch_size=1,
-                                     shuffle=False,
-                                     )
-    # print('full_poisoned_data_idx:', len(losses_idx))
-    perm = losses_idx[0: int(len(losses_idx) * ratio)]
+    def isolate_data(self, data, losses_idx):
+        # Initialize lists
+        other_examples = []
+        isolation_examples = []

-    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-        img = img.squeeze()
-        target = target.squeeze()
-        img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
-        target = target.cpu().numpy()
+        cnt = 0
+        ratio = self.opt.isolation_ratio

-        # Filter the examples corresponding to losses_idx
-        if idx in perm:
-            isolation_examples.append((img, target))
-            cnt += 1
-        else:
-            other_examples.append((img, target))
+        example_data_loader = DataLoader(dataset=data,
+                                         batch_size=1,
+                                         shuffle=False,
+                                         )
+        # print('full_poisoned_data_idx:', len(losses_idx))
+        perm = losses_idx[0: int(len(losses_idx) * ratio)]

-    # Save data
-    if opt.save:
-        data_path_isolation = os.path.join(opt.isolate_data_root, "{}_isolation{}%_examples.npy".format(opt.model_name,
-                                                                                                        opt.isolation_ratio * 100))
-        data_path_other = os.path.join(opt.isolate_data_root, "{}_other{}%_examples.npy".format(opt.model_name,
-                                                                                                100 - opt.isolation_ratio * 100))
-        if os.path.exists(data_path_isolation):
-            raise ValueError('isolation data already exists')
-        else:
-            # save the isolation examples
-            np.save(data_path_isolation, isolation_examples)
-            np.save(data_path_other, other_examples)
+        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+            img = img.squeeze()
+            target = target.squeeze()
+            # Convert CHW float tensors in [0, 1] back to HWC uint8 images
+            img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
+            target = target.cpu().numpy()

-    print('Finish collecting {} isolation examples: '.format(len(isolation_examples)))
-    print('Finish collecting {} other examples: '.format(len(other_examples)))
+            # Filter the examples corresponding to losses_idx
+            if idx in perm:
+                isolation_examples.append((img, target))
+                cnt += 1
+            else:
+                other_examples.append((img, target))

+        # Save data
+        if self.opt.save:
+            data_path_isolation = os.path.join(self.opt.isolate_data_root,
+                                               "{}_isolation{}%_examples.npy".format(self.opt.model_name,
+                                                                                     self.opt.isolation_ratio * 100))
+            data_path_other = os.path.join(self.opt.isolate_data_root,
+                                           "{}_other{}%_examples.npy".format(self.opt.model_name,
+                                                                             100 - self.opt.isolation_ratio * 100))
+            if os.path.exists(data_path_isolation):
+                raise ValueError('isolation data already exists')
+            else:
+                # Save the isolation examples and the remaining examples separately
+                np.save(data_path_isolation, isolation_examples)
+                np.save(data_path_other, other_examples)

-def train_step(opt, train_loader, model_ascent, optimizer, criterion, epoch):
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
+        print('Finished collecting {} isolation examples'.format(len(isolation_examples)))
+        print('Finished collecting {} other examples'.format(len(other_examples)))

-    model_ascent.train()

-    for idx, (img, target) in enumerate(train_loader, start=1):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()

-        if opt.gradient_ascent_type == 'LGA':
-            output = model_ascent(img)
-            loss = criterion(output, target)
-            # add Local Gradient Ascent(LGA) loss
-            loss_ascent = torch.sign(loss - opt.gamma) * loss

-        elif opt.gradient_ascent_type == 'Flooding':
-            output = model_ascent(img)
-            # output = student(img)
-            loss = criterion(output, target)
-            # add flooding loss
-            loss_ascent = (loss - opt.flooding).abs() + opt.flooding
-
-        else:
-            raise NotImplementedError
-
-        prec1, prec5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss_ascent.item(), img.size(0))
-        top1.update(prec1.item(), img.size(0))
-        top5.update(prec5.item(), img.size(0))
-
-        optimizer.zero_grad()
-        loss_ascent.backward()
-        optimizer.step()
-
-        if idx % opt.print_freq == 0:
-            print('Epoch[{0}]:[{1:03}/{2:03}] '
-                  'Loss:{losses.val:.4f}({losses.avg:.4f}) '
-                  'Prec@1:{top1.val:.2f}({top1.avg:.2f}) '
-                  'Prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=losses, top1=top1, top5=top5))
-
-
-def test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch):
-    test_process = []
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
-
-    model_ascent.eval()
-
-    for idx, (img, target) in enumerate(test_clean_loader, start=1):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()
-
-        with torch.no_grad():
-            output = model_ascent(img)
-            loss = criterion(output, target)
-
-        prec1, prec5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss.item(), img.size(0))
-        top1.update(prec1.item(), img.size(0))
-        top5.update(prec5.item(), img.size(0))
-
-    acc_clean = [top1.avg, top5.avg, losses.avg]
-
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
-
-    for idx, (img, target) in enumerate(test_bad_loader, start=1):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()
-
-        with torch.no_grad():
-            output = model_ascent(img)
-            loss = criterion(output, target)
-
-        prec1, prec5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss.item(), img.size(0))
-        top1.update(prec1.item(), img.size(0))
-        top5.update(prec5.item(), img.size(0))
-
-    acc_bd = [top1.avg, top5.avg, losses.avg]
-
-    print('[Clean] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_clean[0], acc_clean[2]))
-    print('[Bad] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_bd[0], acc_bd[2]))
-
-    # save training progress
-    if epoch < opt.tuning_epochs + 1:
-        log_root = opt.log_root + '/ABL_results_tuning_epochs.csv'
-        test_process.append(
-            (epoch, acc_clean[0], acc_bd[0], acc_clean[2], acc_bd[2]))
-        df = pd.DataFrame(test_process, columns=("Epoch", "Test_clean_acc", "Test_bad_acc",
-                                                 "Test_clean_loss", "Test_bad_loss"))
-        df.to_csv(log_root, mode='a', index=False, encoding='utf-8')
-
-    return acc_clean, acc_bd
-
-
-def train(opt):
-    # Load models
-    print('----------- Network Initialization --------------')
-    model_ascent, _ = select_model(dataset=opt.dataset,
-                                   model_name=opt.model_name,
-                                   pretrained=False,
-                                   pretrained_models_path=opt.isolation_model_root,
-                                   n_classes=opt.num_class)
-    model_ascent.to(opt.device)
-    print('finished model init...')
-
-    # initialize optimizer
-    optimizer = torch.optim.SGD(model_ascent.parameters(),
-                                lr=opt.lr,
-                                momentum=opt.momentum,
-                                weight_decay=opt.weight_decay,
-                                nesterov=True)
-
-    # define loss functions
-    if opt.cuda:
-        criterion = nn.CrossEntropyLoss().cuda()
-    else:
-        criterion = nn.CrossEntropyLoss()
-
-    print('----------- Data Initialization --------------')
-    if opt.load_fixed_data:
-        tf_compose = transforms.Compose([
-            transforms.ToTensor()
-        ])
-        # load the fixed poisoned data, e.g. Dynamic, FC, DFST attacks etc.
-        poisoned_data = np.load(opt.poisoned_data_path, allow_pickle=True)
-        poisoned_data_loader = DataLoader(dataset=poisoned_data,
-                                          batch_size=opt.batch_size,
-                                          shuffle=True,
-                                          )
-    else:
-        poisoned_data, poisoned_data_loader = get_backdoor_loader(opt)
-
-    test_clean_loader, test_bad_loader = get_test_loader(opt)
-
-    print('----------- Train Initialization --------------')
-    for epoch in range(0, opt.tuning_epochs):
-
-        adjust_learning_rate(optimizer, epoch, opt)
-
-        # train every epoch
-        if epoch == 0:
-            # before training test firstly
-            test(opt, test_clean_loader, test_bad_loader, model_ascent,
-                 criterion, epoch + 1)
-
-        train_step(opt, poisoned_data_loader, model_ascent, optimizer, criterion, epoch + 1)
-
-        # evaluate on testing set
-        print('testing the ascended model......')
-        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch + 1)
-
-        if opt.save:
-            # remember best precision and save checkpoint
-            # is_best = acc_clean[0] > opt.threshold_clean
-            # opt.threshold_clean = min(acc_clean[0], opt.threshold_clean)
-            #
-            # best_clean_acc = acc_clean[0]
-            # best_bad_acc = acc_bad[0]
-            #
-            # save_checkpoint({
-            #     'epoch': epoch,
-            #     'state_dict': model_ascent.state_dict(),
-            #     'clean_acc': best_clean_acc,
-            #     'bad_acc': best_bad_acc,
-            #     'optimizer': optimizer.state_dict(),
-            # }, epoch, is_best, opt.checkpoint_root, opt.model_name)
-
-            # save checkpoint at interval epoch
-            if epoch % opt.interval == 0:
-                is_best = True
-                save_checkpoint({
-                    'epoch': epoch + 1,
-                    'state_dict': model_ascent.state_dict(),
-                    'clean_acc': acc_clean[0],
-                    'bad_acc': acc_bad[0],
-                    'optimizer': optimizer.state_dict(),
-                }, epoch, is_best, opt)
-
-    return poisoned_data, model_ascent
-
-
-def adjust_learning_rate(optimizer, epoch, opt):
-    if epoch < opt.tuning_epochs:
-        lr = opt.lr
-    else:
-        lr = 0.01
-    print('epoch: {}  lr: {:.4f}'.format(epoch, lr))
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
-
-
-def save_checkpoint(state, epoch, is_best, opt):
-    if is_best:
-        filepath = os.path.join(opt.save, opt.model_name + r'-tuning_epochs{}.tar'.format(epoch))
-        torch.save(state, filepath)
-        print('[info] Finish saving the model')
-
-
-def main():
-    print('----------- Train isolated model -----------')
-    opt = get_arguments().parse_args()
-    poisoned_data, ascent_model = train(opt)
-
-    print('----------- Calculate loss value per example -----------')
-    losses_idx = compute_loss_value(opt, poisoned_data, ascent_model)
-
-    print('----------- Collect isolation data -----------')
-    isolate_data(opt, poisoned_data, losses_idx)
-
-
-if (__name__ == '__main__'):
-    main()
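
Note: after this change the file no longer has an entry point. main(), train() and the config/data-loader imports are removed, so ABLAnalysis has to be driven from outside. Below is a minimal driver sketch, assuming the helpers from the original ABL repo (get_arguments, get_backdoor_loader, select_model) are still importable, that a gradient-ascent model has already been trained and saved, and that this file is importable as backdoor_isolation (the module name is hypothetical):

from config import get_arguments
from data_loader import get_backdoor_loader
from models.selector import select_model

from backdoor_isolation import ABLAnalysis  # hypothetical module name for this file


def run_isolation():
    opt = get_arguments().parse_args()

    # Poisoned training set; the accompanying DataLoader is not needed here
    poisoned_data, _ = get_backdoor_loader(opt)

    # A model already trained with gradient ascent on the poisoned data
    # (pretrained=True to load the saved weights is an assumption)
    model_ascent, _ = select_model(dataset=opt.dataset,
                                   model_name=opt.model_name,
                                   pretrained=True,
                                   pretrained_models_path=opt.isolation_model_root,
                                   n_classes=opt.num_class)
    if opt.cuda:
        model_ascent = model_ascent.cuda()

    analysis = ABLAnalysis(opt)
    losses_idx = analysis.compute_loss_value(poisoned_data, model_ascent)
    analysis.isolate_data(poisoned_data, losses_idx)


if __name__ == '__main__':
    run_isolation()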