-import torch
+from models.selector import *
+from utils.util import *
+from data_loader import *
+from torch.utils.data import DataLoader
+from config import get_arguments
+from tqdm import tqdm
+
+# Explicit imports for names used below (torch, nn, np, os, pd, transforms).
+# The star imports above may already re-export some of these, but relying on
+# that is fragile, and `import torch` is removed by this change yet still needed.
+import os
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+from torchvision import transforms


-def min_max_normalization(x):
-    x_min = torch.min(x)
-    x_max = torch.max(x)
-    norm = (x - x_min) / (x_max - x_min)
-    return norm
+def compute_loss_value(opt, poisoned_data, model_ascent):
+    # Calculate loss value per example
+    # Define loss function
+    if opt.cuda:
+        criterion = nn.CrossEntropyLoss().cuda()
+    else:
+        criterion = nn.CrossEntropyLoss()

+    model_ascent.eval()
+    losses_record = []

-class ABLAnalysis():
-    def __init__(self):
-        # Based on https://github.com/bboylyg/ABL/blob/main/backdoor_isolation.py
-        return
+    example_data_loader = DataLoader(dataset=poisoned_data,
+                                     batch_size=1,
+                                     shuffle=False,
+                                     )

-    def compute_loss_value(self, data, model_ascent):
-        # Calculate loss value per example
-        # Define loss function
-        if opt.cuda:
-            criterion = nn.CrossEntropyLoss().cuda()
-        else:
-            criterion = nn.CrossEntropyLoss()
+    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()

-        model_ascent.eval()
-        losses_record = []
+        with torch.no_grad():
+            output = model_ascent(img)
+            loss = criterion(output, target)
+            # print(loss.item())

-        example_data_loader = DataLoader(dataset=poisoned_data,
-                                         batch_size=1,
-                                         shuffle=False,
-                                         )
+        losses_record.append(loss.item())

-        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-            if opt.cuda:
-                img = img.cuda()
-                target = target.cuda()
+    losses_idx = np.argsort(np.array(losses_record))  # get the index of examples by loss value in ascending order

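+    # ABL note: after gradient-ascent pre-training, poisoned examples tend to
+    # be fitted fastest and therefore sit at the front of this ascending loss
+    # ranking, which is what isolate_data() exploits below.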
-            with torch.no_grad():
-                output = model_ascent(img)
-                loss = criterion(output, target)
-                # print(loss.item())
+    # Show the ten lowest loss values
+    losses_record_arr = np.array(losses_record)
+    print('Ten lowest loss values:', losses_record_arr[losses_idx[:10]])

-            losses_record.append(loss.item())
+    return losses_idx

-        losses_idx = np.argsort(np.array(losses_record))  # get the index of examples by loss value in ascending order

-        # Show the lowest 10 loss values
-        losses_record_arr = np.array(losses_record)
-        print('Top ten loss value:', losses_record_arr[losses_idx[:10]])
+def isolate_data(opt, poisoned_data, losses_idx):
+    # Initialize lists
+    other_examples = []
+    isolation_examples = []

-        return losses_idx
+    cnt = 0
+    ratio = opt.isolation_ratio

-    def isolate_data(self, data, losses_idx):
-        # Initialize lists
-        other_examples = []
-        isolation_examples = []
+    example_data_loader = DataLoader(dataset=poisoned_data,
+                                     batch_size=1,
+                                     shuffle=False,
+                                     )
+    # print('full_poisoned_data_idx:', len(losses_idx))
+    perm = losses_idx[0: int(len(losses_idx) * ratio)]

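+    # perm holds the indices of the ratio*N lowest-loss examples; wrapping it
+    # in set() would make the `idx in perm` membership tests below O(1)
+    # instead of a linear scan per example.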
-        cnt = 0
-        ratio = opt.isolation_ratio
+    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+        img = img.squeeze()
+        target = target.squeeze()
+        img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
+        target = target.cpu().numpy()

-        example_data_loader = DataLoader(dataset=poisoned_data,
-                                         batch_size=1,
-                                         shuffle=False,
-                                         )
-        # print('full_poisoned_data_idx:', len(losses_idx))
-        perm = losses_idx[0: int(len(losses_idx) * ratio)]
+        # Filter the examples corresponding to losses_idx
+        if idx in perm:
+            isolation_examples.append((img, target))
+            cnt += 1
+        else:
+            other_examples.append((img, target))

-        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-            img = img.squeeze()
-            target = target.squeeze()
-            img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
-            target = target.cpu().numpy()
+    # Save data
+    if opt.save:
+        data_path_isolation = os.path.join(opt.isolate_data_root,
+                                           "{}_isolation{}%_examples.npy".format(opt.model_name,
+                                                                                 opt.isolation_ratio * 100))
+        data_path_other = os.path.join(opt.isolate_data_root,
+                                       "{}_other{}%_examples.npy".format(opt.model_name,
+                                                                         100 - opt.isolation_ratio * 100))
+        if os.path.exists(data_path_isolation):
+            raise ValueError('isolation data already exists')
+        else:
+            # save the isolation examples
+            np.save(data_path_isolation, isolation_examples)
+            np.save(data_path_other, other_examples)
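+            # Each saved entry is an (img, target) tuple, so NumPy effectively
+            # stores an object array here; loading it back requires
+            # allow_pickle=True, as done for opt.poisoned_data_path in train().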

-        # Filter the examples corresponding to losses_idx
-        if idx in perm:
-            isolation_examples.append((img, target))
-            cnt += 1
-        else:
-            other_examples.append((img, target))
+    print('Finished collecting {} isolation examples.'.format(len(isolation_examples)))
+    print('Finished collecting {} other examples.'.format(len(other_examples)))

-        # Save data
-        if opt.save:
-            data_path_isolation = os.path.join(opt.isolate_data_root,
-                                               "{}_isolation{}%_examples.npy".format(opt.model_name,
-                                                                                     opt.isolation_ratio * 100))
-            data_path_other = os.path.join(opt.isolate_data_root, "{}_other{}%_examples.npy".format(opt.model_name,
-                                                                                                    100 - opt.isolation_ratio * 100))
-            if os.path.exists(data_path_isolation):
-                raise ValueError('isolation data already exists')
-            else:
-                # save the isolation examples
-                np.save(data_path_isolation, isolation_examples)
-                np.save(data_path_other, other_examples)

-        print('Finish collecting {} isolation examples: '.format(len(isolation_examples)))
-        print('Finish collecting {} other examples: '.format(len(other_examples)))
+def train_step(opt, train_loader, model_ascent, optimizer, criterion, epoch):
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()

+    model_ascent.train()

+    for idx, (img, target) in enumerate(train_loader, start=1):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()

+        if opt.gradient_ascent_type == 'LGA':
+            output = model_ascent(img)
+            loss = criterion(output, target)
+            # add Local Gradient Ascent (LGA) loss
+            loss_ascent = torch.sign(loss - opt.gamma) * loss
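+            # LGA: sign(loss - gamma) is -1 once the loss falls below
+            # opt.gamma, so gradient descent on loss_ascent turns into ascent
+            # and no example's loss can settle far below the threshold.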

+        elif opt.gradient_ascent_type == 'Flooding':
+            output = model_ascent(img)
+            # output = student(img)
+            loss = criterion(output, target)
+            # add flooding loss
+            loss_ascent = (loss - opt.flooding).abs() + opt.flooding
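+            # Flooding (Ishida et al., 2020): |loss - b| + b descends while the
+            # loss is above the flood level opt.flooding and ascends below it.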
+
+        else:
+            raise NotImplementedError
+
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss_ascent.item(), img.size(0))
+        top1.update(prec1.item(), img.size(0))
+        top5.update(prec5.item(), img.size(0))
+
+        optimizer.zero_grad()
+        loss_ascent.backward()
+        optimizer.step()
+
+        if idx % opt.print_freq == 0:
+            print('Epoch[{0}]:[{1:03}/{2:03}] '
+                  'Loss:{losses.val:.4f}({losses.avg:.4f}) '
+                  'Prec@1:{top1.val:.2f}({top1.avg:.2f}) '
+                  'Prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=losses, top1=top1, top5=top5))
+
+
+def test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch):
+    test_process = []
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    model_ascent.eval()
+
+    for idx, (img, target) in enumerate(test_clean_loader, start=1):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()
+
+        with torch.no_grad():
+            output = model_ascent(img)
+            loss = criterion(output, target)
+
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), img.size(0))
+        top1.update(prec1.item(), img.size(0))
+        top5.update(prec5.item(), img.size(0))
+
+    acc_clean = [top1.avg, top5.avg, losses.avg]
+
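+    # Reset the meters before evaluating on the backdoored test set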
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    for idx, (img, target) in enumerate(test_bad_loader, start=1):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()
+
+        with torch.no_grad():
+            output = model_ascent(img)
+            loss = criterion(output, target)
+
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), img.size(0))
+        top1.update(prec1.item(), img.size(0))
+        top5.update(prec5.item(), img.size(0))
+
+    acc_bd = [top1.avg, top5.avg, losses.avg]
+
+    print('[Clean] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_clean[0], acc_clean[2]))
+    print('[Bad] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_bd[0], acc_bd[2]))
+
+    # save training progress
+    if epoch < opt.tuning_epochs + 1:
+        log_root = os.path.join(opt.log_root, 'ABL_results_tuning_epochs.csv')
+        test_process.append(
+            (epoch, acc_clean[0], acc_bd[0], acc_clean[2], acc_bd[2]))
+        df = pd.DataFrame(test_process, columns=("Epoch", "Test_clean_acc", "Test_bad_acc",
+                                                 "Test_clean_loss", "Test_bad_loss"))
+        df.to_csv(log_root, mode='a', index=False, encoding='utf-8')
+
+    return acc_clean, acc_bd
+
+
+def train(opt):
+    # Load models
+    print('----------- Network Initialization --------------')
+    model_ascent, _ = select_model(dataset=opt.dataset,
+                                   model_name=opt.model_name,
+                                   pretrained=False,
+                                   pretrained_models_path=opt.isolation_model_root,
+                                   n_classes=opt.num_class)
+    model_ascent.to(opt.device)
+    print('finished model init...')
+
+    # initialize optimizer
+    optimizer = torch.optim.SGD(model_ascent.parameters(),
+                                lr=opt.lr,
+                                momentum=opt.momentum,
+                                weight_decay=opt.weight_decay,
+                                nesterov=True)
+
+    # define loss functions
+    if opt.cuda:
+        criterion = nn.CrossEntropyLoss().cuda()
+    else:
+        criterion = nn.CrossEntropyLoss()
+
+    print('----------- Data Initialization --------------')
+    if opt.load_fixed_data:
+        tf_compose = transforms.Compose([
+            transforms.ToTensor()
+        ])
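+        # NB: tf_compose is built but never applied below; presumably the
+        # fixed poisoned data was saved in already-transformed tensor form.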
+        # load the fixed poisoned data, e.g. Dynamic, FC, DFST attacks etc.
+        poisoned_data = np.load(opt.poisoned_data_path, allow_pickle=True)
+        poisoned_data_loader = DataLoader(dataset=poisoned_data,
+                                          batch_size=opt.batch_size,
+                                          shuffle=True,
+                                          )
+    else:
+        poisoned_data, poisoned_data_loader = get_backdoor_loader(opt)
+
+    test_clean_loader, test_bad_loader = get_test_loader(opt)
+
+    print('----------- Train Initialization --------------')
+    for epoch in range(0, opt.tuning_epochs):
+
+        adjust_learning_rate(optimizer, epoch, opt)
+
+        # train every epoch
+        if epoch == 0:
+            # test first, before any training
+            test(opt, test_clean_loader, test_bad_loader, model_ascent,
+                 criterion, epoch + 1)
+
+        train_step(opt, poisoned_data_loader, model_ascent, optimizer, criterion, epoch + 1)
+
+        # evaluate on the test sets
+        print('testing the ascended model......')
+        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch + 1)
+
+        if opt.save:
+            # remember best precision and save checkpoint
+            # is_best = acc_clean[0] > opt.threshold_clean
+            # opt.threshold_clean = min(acc_clean[0], opt.threshold_clean)
+            #
+            # best_clean_acc = acc_clean[0]
+            # best_bad_acc = acc_bad[0]
+            #
+            # save_checkpoint({
+            #     'epoch': epoch,
+            #     'state_dict': model_ascent.state_dict(),
+            #     'clean_acc': best_clean_acc,
+            #     'bad_acc': best_bad_acc,
+            #     'optimizer': optimizer.state_dict(),
+            # }, epoch, is_best, opt.checkpoint_root, opt.model_name)
+
+            # save checkpoint at interval epoch
+            if epoch % opt.interval == 0:
+                is_best = True
+                save_checkpoint({
+                    'epoch': epoch + 1,
+                    'state_dict': model_ascent.state_dict(),
+                    'clean_acc': acc_clean[0],
+                    'bad_acc': acc_bad[0],
+                    'optimizer': optimizer.state_dict(),
+                }, epoch, is_best, opt)
+
+    return poisoned_data, model_ascent
+
+
+def adjust_learning_rate(optimizer, epoch, opt):
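+    # During isolation pre-training, epoch is always < opt.tuning_epochs, so
+    # the rate stays at opt.lr; the 0.01 branch only matters if this helper is
+    # reused beyond the tuning phase.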
+    if epoch < opt.tuning_epochs:
+        lr = opt.lr
+    else:
+        lr = 0.01
+    print('epoch: {} lr: {:.4f}'.format(epoch, lr))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def save_checkpoint(state, epoch, is_best, opt):
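+    # Note: opt.save serves both as a truthy flag in train() and as the
+    # output directory here, so it is expected to be a path string.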
+    if is_best:
+        filepath = os.path.join(opt.save, opt.model_name + '-tuning_epochs{}.tar'.format(epoch))
+        torch.save(state, filepath)
+        print('[info] Finished saving the model.')
+
+
+def main():
+    print('----------- Train isolated model -----------')
+    opt = get_arguments().parse_args()
+    poisoned_data, ascent_model = train(opt)
+
+    print('----------- Calculate loss value per example -----------')
+    losses_idx = compute_loss_value(opt, poisoned_data, ascent_model)
+
+    print('----------- Collect isolation data -----------')
+    isolate_data(opt, poisoned_data, losses_idx)
+
+
+if __name__ == '__main__':
+    main()