
Commit 3fd736c

Update backdoor_isolation.py
1 parent 8aadd54 commit 3fd736c

File tree: 1 file changed (+285 -75 lines)

backdoor_isolation.py
@@ -1,99 +1,309 @@
-import torch
+from models.selector import *
+from utils.util import *
+from data_loader import *
+from torch.utils.data import DataLoader
+from config import get_arguments
+from tqdm import tqdm


-def min_max_normalization(x):
-    x_min = torch.min(x)
-    x_max = torch.max(x)
-    norm = (x - x_min) / (x_max - x_min)
-    return norm
+def compute_loss_value(opt, poisoned_data, model_ascent):
+    # Calculate loss value per example
+    # Define loss function
+    if opt.cuda:
+        criterion = nn.CrossEntropyLoss().cuda()
+    else:
+        criterion = nn.CrossEntropyLoss()

+    model_ascent.eval()
+    losses_record = []

-class ABLAnalysis():
-    def __init__(self):
-        # Based on https://github.com/bboylyg/ABL/blob/main/backdoor_isolation.py
-        return
+    example_data_loader = DataLoader(dataset=poisoned_data,
+                                     batch_size=1,
+                                     shuffle=False,
+                                     )

-    def compute_loss_value(self, data, model_ascent):
-        # Calculate loss value per example
-        # Define loss function
-        if opt.cuda:
-            criterion = nn.CrossEntropyLoss().cuda()
-        else:
-            criterion = nn.CrossEntropyLoss()
+    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()

-        model_ascent.eval()
-        losses_record = []
+        with torch.no_grad():
+            output = model_ascent(img)
+            loss = criterion(output, target)
+            # print(loss.item())

-        example_data_loader = DataLoader(dataset=poisoned_data,
-                                         batch_size=1,
-                                         shuffle=False,
-                                         )
+        losses_record.append(loss.item())

-        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-            if opt.cuda:
-                img = img.cuda()
-                target = target.cuda()
+    losses_idx = np.argsort(np.array(losses_record))  # get the index of examples by loss value in ascending order

-            with torch.no_grad():
-                output = model_ascent(img)
-                loss = criterion(output, target)
-                # print(loss.item())
+    # Show the lowest 10 loss values
+    losses_record_arr = np.array(losses_record)
+    print('Top ten loss value:', losses_record_arr[losses_idx[:10]])

-            losses_record.append(loss.item())
+    return losses_idx

-        losses_idx = np.argsort(np.array(losses_record))  # get the index of examples by loss value in ascending order

-        # Show the lowest 10 loss values
-        losses_record_arr = np.array(losses_record)
-        print('Top ten loss value:', losses_record_arr[losses_idx[:10]])
+def isolate_data(opt, poisoned_data, losses_idx):
+    # Initialize lists
+    other_examples = []
+    isolation_examples = []

-        return losses_idx
+    cnt = 0
+    ratio = opt.isolation_ratio

-    def isolate_data(self, data, losses_idx):
-        # Initialize lists
-        other_examples = []
-        isolation_examples = []
+    example_data_loader = DataLoader(dataset=poisoned_data,
+                                     batch_size=1,
+                                     shuffle=False,
+                                     )
+    # print('full_poisoned_data_idx:', len(losses_idx))
+    perm = losses_idx[0: int(len(losses_idx) * ratio)]

-        cnt = 0
-        ratio = opt.isolation_ratio
+    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+        img = img.squeeze()
+        target = target.squeeze()
+        img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
+        target = target.cpu().numpy()

-        example_data_loader = DataLoader(dataset=poisoned_data,
-                                         batch_size=1,
-                                         shuffle=False,
-                                         )
-        # print('full_poisoned_data_idx:', len(losses_idx))
-        perm = losses_idx[0: int(len(losses_idx) * ratio)]
+        # Filter the examples corresponding to losses_idx
+        if idx in perm:
+            isolation_examples.append((img, target))
+            cnt += 1
+        else:
+            other_examples.append((img, target))

-        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-            img = img.squeeze()
-            target = target.squeeze()
-            img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
-            target = target.cpu().numpy()
+    # Save data
+    if opt.save:
+        data_path_isolation = os.path.join(opt.isolate_data_root, "{}_isolation{}%_examples.npy".format(opt.model_name,
+                                                                                                        opt.isolation_ratio * 100))
+        data_path_other = os.path.join(opt.isolate_data_root, "{}_other{}%_examples.npy".format(opt.model_name,
+                                                                                                100 - opt.isolation_ratio * 100))
+        if os.path.exists(data_path_isolation):
+            raise ValueError('isolation data already exists')
+        else:
+            # save the isolation examples
+            np.save(data_path_isolation, isolation_examples)
+            np.save(data_path_other, other_examples)

-            # Filter the examples corresponding to losses_idx
-            if idx in perm:
-                isolation_examples.append((img, target))
-                cnt += 1
-            else:
-                other_examples.append((img, target))
+    print('Finish collecting {} isolation examples: '.format(len(isolation_examples)))
+    print('Finish collecting {} other examples: '.format(len(other_examples)))

-        # Save data
-        if opt.save:
-            data_path_isolation = os.path.join(opt.isolate_data_root,
-                                               "{}_isolation{}%_examples.npy".format(opt.model_name,
-                                                                                     opt.isolation_ratio * 100))
-            data_path_other = os.path.join(opt.isolate_data_root, "{}_other{}%_examples.npy".format(opt.model_name,
-                                                                                                    100 - opt.isolation_ratio * 100))
-            if os.path.exists(data_path_isolation):
-                raise ValueError('isolation data already exists')
-            else:
-                # save the isolation examples
-                np.save(data_path_isolation, isolation_examples)
-                np.save(data_path_other, other_examples)

-        print('Finish collecting {} isolation examples: '.format(len(isolation_examples)))
-        print('Finish collecting {} other examples: '.format(len(other_examples)))
+def train_step(opt, train_loader, model_ascent, optimizer, criterion, epoch):
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()

+    model_ascent.train()

+    for idx, (img, target) in enumerate(train_loader, start=1):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()

+        if opt.gradient_ascent_type == 'LGA':
+            output = model_ascent(img)
+            loss = criterion(output, target)
+            # add Local Gradient Ascent(LGA) loss
+            loss_ascent = torch.sign(loss - opt.gamma) * loss

+        elif opt.gradient_ascent_type == 'Flooding':
+            output = model_ascent(img)
+            # output = student(img)
+            loss = criterion(output, target)
+            # add flooding loss
+            loss_ascent = (loss - opt.flooding).abs() + opt.flooding
+
+        else:
+            raise NotImplementedError
+
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss_ascent.item(), img.size(0))
+        top1.update(prec1.item(), img.size(0))
+        top5.update(prec5.item(), img.size(0))
+
+        optimizer.zero_grad()
+        loss_ascent.backward()
+        optimizer.step()
+
+        if idx % opt.print_freq == 0:
+            print('Epoch[{0}]:[{1:03}/{2:03}] '
+                  'Loss:{losses.val:.4f}({losses.avg:.4f}) '
+                  'Prec@1:{top1.val:.2f}({top1.avg:.2f}) '
+                  'Prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=losses, top1=top1, top5=top5))
+
+
+def test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch):
+    test_process = []
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    model_ascent.eval()
+
+    for idx, (img, target) in enumerate(test_clean_loader, start=1):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()
+
+        with torch.no_grad():
+            output = model_ascent(img)
+            loss = criterion(output, target)
+
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), img.size(0))
+        top1.update(prec1.item(), img.size(0))
+        top5.update(prec5.item(), img.size(0))
+
+    acc_clean = [top1.avg, top5.avg, losses.avg]
+
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    for idx, (img, target) in enumerate(test_bad_loader, start=1):
+        if opt.cuda:
+            img = img.cuda()
+            target = target.cuda()
+
+        with torch.no_grad():
+            output = model_ascent(img)
+            loss = criterion(output, target)
+
+        prec1, prec5 = accuracy(output, target, topk=(1, 5))
+        losses.update(loss.item(), img.size(0))
+        top1.update(prec1.item(), img.size(0))
+        top5.update(prec5.item(), img.size(0))
+
+    acc_bd = [top1.avg, top5.avg, losses.avg]
+
+    print('[Clean] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_clean[0], acc_clean[2]))
+    print('[Bad] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_bd[0], acc_bd[2]))
+
+    # save training progress
+    if epoch < opt.tuning_epochs + 1:
+        log_root = opt.log_root + '/ABL_results_tuning_epochs.csv'
+        test_process.append(
+            (epoch, acc_clean[0], acc_bd[0], acc_clean[2], acc_bd[2]))
+        df = pd.DataFrame(test_process, columns=("Epoch", "Test_clean_acc", "Test_bad_acc",
+                                                 "Test_clean_loss", "Test_bad_loss"))
+        df.to_csv(log_root, mode='a', index=False, encoding='utf-8')
+
+    return acc_clean, acc_bd
+
+
+def train(opt):
+    # Load models
+    print('----------- Network Initialization --------------')
+    model_ascent, _ = select_model(dataset=opt.dataset,
+                                   model_name=opt.model_name,
+                                   pretrained=False,
+                                   pretrained_models_path=opt.isolation_model_root,
+                                   n_classes=opt.num_class)
+    model_ascent.to(opt.device)
+    print('finished model init...')
+
+    # initialize optimizer
+    optimizer = torch.optim.SGD(model_ascent.parameters(),
+                                lr=opt.lr,
+                                momentum=opt.momentum,
+                                weight_decay=opt.weight_decay,
+                                nesterov=True)
+
+    # define loss functions
+    if opt.cuda:
+        criterion = nn.CrossEntropyLoss().cuda()
+    else:
+        criterion = nn.CrossEntropyLoss()
+
+    print('----------- Data Initialization --------------')
+    if opt.load_fixed_data:
+        tf_compose = transforms.Compose([
+            transforms.ToTensor()
+        ])
+        # load the fixed poisoned data, e.g. Dynamic, FC, DFST attacks etc.
+        poisoned_data = np.load(opt.poisoned_data_path, allow_pickle=True)
+        poisoned_data_loader = DataLoader(dataset=poisoned_data,
+                                          batch_size=opt.batch_size,
+                                          shuffle=True,
+                                          )
+    else:
+        poisoned_data, poisoned_data_loader = get_backdoor_loader(opt)
+
+    test_clean_loader, test_bad_loader = get_test_loader(opt)
+
+    print('----------- Train Initialization --------------')
+    for epoch in range(0, opt.tuning_epochs):
+
+        adjust_learning_rate(optimizer, epoch, opt)
+
+        # train every epoch
+        if epoch == 0:
+            # before training test firstly
+            test(opt, test_clean_loader, test_bad_loader, model_ascent,
+                 criterion, epoch + 1)
+
+        train_step(opt, poisoned_data_loader, model_ascent, optimizer, criterion, epoch + 1)
+
+        # evaluate on testing set
+        print('testing the ascended model......')
+        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch + 1)
+
+        if opt.save:
+            # remember best precision and save checkpoint
+            # is_best = acc_clean[0] > opt.threshold_clean
+            # opt.threshold_clean = min(acc_clean[0], opt.threshold_clean)
+            #
+            # best_clean_acc = acc_clean[0]
+            # best_bad_acc = acc_bad[0]
+            #
+            # save_checkpoint({
+            #     'epoch': epoch,
+            #     'state_dict': model_ascent.state_dict(),
+            #     'clean_acc': best_clean_acc,
+            #     'bad_acc': best_bad_acc,
+            #     'optimizer': optimizer.state_dict(),
+            # }, epoch, is_best, opt.checkpoint_root, opt.model_name)
+
+            # save checkpoint at interval epoch
+            if epoch % opt.interval == 0:
+                is_best = True
+                save_checkpoint({
+                    'epoch': epoch + 1,
+                    'state_dict': model_ascent.state_dict(),
+                    'clean_acc': acc_clean[0],
+                    'bad_acc': acc_bad[0],
+                    'optimizer': optimizer.state_dict(),
+                }, epoch, is_best, opt)
+
+    return poisoned_data, model_ascent
+
+
+def adjust_learning_rate(optimizer, epoch, opt):
+    if epoch < opt.tuning_epochs:
+        lr = opt.lr
+    else:
+        lr = 0.01
+    print('epoch: {} lr: {:.4f}'.format(epoch, lr))
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+def save_checkpoint(state, epoch, is_best, opt):
+    if is_best:
+        filepath = os.path.join(opt.save, opt.model_name + r'-tuning_epochs{}.tar'.format(epoch))
+        torch.save(state, filepath)
+    print('[info] Finish saving the model')
+
+def main():
+    print('----------- Train isolated model -----------')
+    opt = get_arguments().parse_args()
+    poisoned_data, ascent_model = train(opt)
+
+    print('----------- Calculate loss value per example -----------')
+    losses_idx = compute_loss_value(opt, poisoned_data, ascent_model)
+
+    print('----------- Collect isolation data -----------')
+    isolate_data(opt, poisoned_data, losses_idx)
+
+if (__name__ == '__main__'):
+    main()
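Note on the isolation rule added in this commit: compute_loss_value records one cross-entropy loss per training example, argsorts the losses in ascending order, and isolate_data then treats the lowest isolation_ratio fraction of indices as the suspected-backdoor split. A minimal standalone sketch of that rule, with hypothetical losses and ratio (the script itself reads opt.isolation_ratio and the ascent model's outputs):

import numpy as np

# Hypothetical per-example losses from the ascent model (one value per training example).
losses_record = np.array([2.31, 0.02, 1.87, 0.01, 0.95, 0.03])
isolation_ratio = 0.5  # hypothetical value; the script reads opt.isolation_ratio

losses_idx = np.argsort(losses_record)                       # ascending: lowest-loss examples first
perm = losses_idx[: int(len(losses_idx) * isolation_ratio)]  # indices flagged as suspected backdoor data

isolation = sorted(i for i in range(len(losses_record)) if i in perm)
other = sorted(i for i in range(len(losses_record)) if i not in perm)
print(isolation, other)  # [1, 3, 5] [0, 2, 4]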

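The two gradient_ascent_type branches in train_step differ only in how they reshape the cross-entropy loss before backpropagation: the LGA branch flips the loss's sign whenever it drops below gamma (so the optimizer ascends on easy, low-loss examples), while the Flooding branch reflects the loss upward around the flooding level. A minimal sketch with hypothetical gamma and flooding values (the script takes them from opt.gamma and opt.flooding):

import torch

loss = torch.tensor(0.25)   # hypothetical cross-entropy value for one batch
gamma, flooding = 0.5, 0.5  # hypothetical; the script uses opt.gamma / opt.flooding

lga_loss = torch.sign(loss - gamma) * loss          # negative below gamma, so gradient descent becomes ascent
flooding_loss = (loss - flooding).abs() + flooding  # never falls below the flooding level

print(lga_loss.item(), flooding_loss.item())  # -0.25 0.75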