
Commit 8aadd54

Update backdoor_isolation.py
1 parent 8cff2fe commit 8aadd54

File tree

1 file changed: +75, -285 lines


backdoor_isolation.py

Lines changed: 75 additions & 285 deletions
@@ -1,309 +1,99 @@
-from models.selector import *
-from utils.util import *
-from data_loader import *
-from torch.utils.data import DataLoader
-from config import get_arguments
-from tqdm import tqdm
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from tqdm import tqdm


-def compute_loss_value(opt, poisoned_data, model_ascent):
-    # Calculate loss value per example
-    # Define loss function
-    if opt.cuda:
-        criterion = nn.CrossEntropyLoss().cuda()
-    else:
-        criterion = nn.CrossEntropyLoss()
+def min_max_normalization(x):
+    # Rescale a tensor linearly to [0, 1]; undefined if x is constant
+    x_min = torch.min(x)
+    x_max = torch.max(x)
+    norm = (x - x_min) / (x_max - x_min)
+    return norm

-    model_ascent.eval()
-    losses_record = []

-    example_data_loader = DataLoader(dataset=poisoned_data,
-                                     batch_size=1,
-                                     shuffle=False,
-                                     )
+class ABLAnalysis():
+    # Based on https://github.com/bboylyg/ABL/blob/main/backdoor_isolation.py
+    def __init__(self, opt):
+        self.opt = opt

-    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()
+    def compute_loss_value(self, data, model_ascent):
+        # Calculate the loss value per example under the ascent-trained model
+        if self.opt.cuda:
+            criterion = nn.CrossEntropyLoss().cuda()
+        else:
+            criterion = nn.CrossEntropyLoss()

-        with torch.no_grad():
-            output = model_ascent(img)
-            loss = criterion(output, target)
-            # print(loss.item())
+        model_ascent.eval()
+        losses_record = []

-        losses_record.append(loss.item())
+        example_data_loader = DataLoader(dataset=data,
+                                         batch_size=1,
+                                         shuffle=False,
+                                         )

-    losses_idx = np.argsort(np.array(losses_record)) # get the index of examples by loss value in ascending order
+        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+            if self.opt.cuda:
+                img = img.cuda()
+                target = target.cuda()

-    # Show the lowest 10 loss values
-    losses_record_arr = np.array(losses_record)
-    print('Top ten loss value:', losses_record_arr[losses_idx[:10]])
+            with torch.no_grad():
+                output = model_ascent(img)
+                loss = criterion(output, target)

-    return losses_idx
+            losses_record.append(loss.item())

+        losses_idx = np.argsort(np.array(losses_record))  # indices of examples sorted by loss value, ascending

-def isolate_data(opt, poisoned_data, losses_idx):
-    # Initialize lists
-    other_examples = []
-    isolation_examples = []
+        # Show the ten lowest loss values
+        losses_record_arr = np.array(losses_record)
+        print('Lowest ten loss values:', losses_record_arr[losses_idx[:10]])

-    cnt = 0
-    ratio = opt.isolation_ratio
+        return losses_idx

-    example_data_loader = DataLoader(dataset=poisoned_data,
-                                     batch_size=1,
-                                     shuffle=False,
-                                     )
-    # print('full_poisoned_data_idx:', len(losses_idx))
-    perm = losses_idx[0: int(len(losses_idx) * ratio)]
+    def isolate_data(self, data, losses_idx):
+        # Initialize lists
+        other_examples = []
+        isolation_examples = []

-    for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
-        img = img.squeeze()
-        target = target.squeeze()
-        img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
-        target = target.cpu().numpy()
+        cnt = 0
+        ratio = self.opt.isolation_ratio

-        # Filter the examples corresponding to losses_idx
-        if idx in perm:
-            isolation_examples.append((img, target))
-            cnt += 1
-        else:
-            other_examples.append((img, target))
+        example_data_loader = DataLoader(dataset=data,
+                                         batch_size=1,
+                                         shuffle=False,
+                                         )
+        perm = losses_idx[0: int(len(losses_idx) * ratio)]

-    # Save data
-    if opt.save:
-        data_path_isolation = os.path.join(opt.isolate_data_root, "{}_isolation{}%_examples.npy".format(opt.model_name,
-                                                                                                        opt.isolation_ratio * 100))
-        data_path_other = os.path.join(opt.isolate_data_root, "{}_other{}%_examples.npy".format(opt.model_name,
-                                                                                                100 - opt.isolation_ratio * 100))
-        if os.path.exists(data_path_isolation):
-            raise ValueError('isolation data already exists')
-        else:
-            # save the isolation examples
-            np.save(data_path_isolation, isolation_examples)
-            np.save(data_path_other, other_examples)
+        for idx, (img, target) in tqdm(enumerate(example_data_loader, start=0)):
+            img = img.squeeze()
+            target = target.squeeze()
+            img = np.transpose((img * 255).cpu().numpy(), (1, 2, 0)).astype('uint8')
+            target = target.cpu().numpy()

-    print('Finish collecting {} isolation examples: '.format(len(isolation_examples)))
-    print('Finish collecting {} other examples: '.format(len(other_examples)))
+            # Keep the lowest-loss examples as the suspected-poison split
+            if idx in perm:
+                isolation_examples.append((img, target))
+                cnt += 1
+            else:
+                other_examples.append((img, target))

+        # Save data
+        if self.opt.save:
+            data_path_isolation = os.path.join(self.opt.isolate_data_root,
+                                               "{}_isolation{}%_examples.npy".format(self.opt.model_name,
+                                                                                     self.opt.isolation_ratio * 100))
+            data_path_other = os.path.join(self.opt.isolate_data_root,
+                                           "{}_other{}%_examples.npy".format(self.opt.model_name,
+                                                                             100 - self.opt.isolation_ratio * 100))
+            if os.path.exists(data_path_isolation):
+                raise ValueError('isolation data already exists')
+            else:
+                # save the isolation examples and the remaining examples
+                np.save(data_path_isolation, isolation_examples)
+                np.save(data_path_other, other_examples)

-def train_step(opt, train_loader, model_ascent, optimizer, criterion, epoch):
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
+        print('Finished collecting {} isolation examples'.format(len(isolation_examples)))
+        print('Finished collecting {} other examples'.format(len(other_examples)))

-    model_ascent.train()

-    for idx, (img, target) in enumerate(train_loader, start=1):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()

-        if opt.gradient_ascent_type == 'LGA':
-            output = model_ascent(img)
-            loss = criterion(output, target)
-            # add Local Gradient Ascent(LGA) loss
-            loss_ascent = torch.sign(loss - opt.gamma) * loss

-        elif opt.gradient_ascent_type == 'Flooding':
-            output = model_ascent(img)
-            # output = student(img)
-            loss = criterion(output, target)
-            # add flooding loss
-            loss_ascent = (loss - opt.flooding).abs() + opt.flooding
-
-        else:
-            raise NotImplementedError
-
-        prec1, prec5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss_ascent.item(), img.size(0))
-        top1.update(prec1.item(), img.size(0))
-        top5.update(prec5.item(), img.size(0))
-
-        optimizer.zero_grad()
-        loss_ascent.backward()
-        optimizer.step()
-
-        if idx % opt.print_freq == 0:
-            print('Epoch[{0}]:[{1:03}/{2:03}] '
-                  'Loss:{losses.val:.4f}({losses.avg:.4f}) '
-                  'Prec@1:{top1.val:.2f}({top1.avg:.2f}) '
-                  'Prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(epoch, idx, len(train_loader), losses=losses, top1=top1, top5=top5))
-
-
-def test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch):
-    test_process = []
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
-
-    model_ascent.eval()
-
-    for idx, (img, target) in enumerate(test_clean_loader, start=1):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()
-
-        with torch.no_grad():
-            output = model_ascent(img)
-            loss = criterion(output, target)
-
-        prec1, prec5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss.item(), img.size(0))
-        top1.update(prec1.item(), img.size(0))
-        top5.update(prec5.item(), img.size(0))
-
-    acc_clean = [top1.avg, top5.avg, losses.avg]
-
-    losses = AverageMeter()
-    top1 = AverageMeter()
-    top5 = AverageMeter()
-
-    for idx, (img, target) in enumerate(test_bad_loader, start=1):
-        if opt.cuda:
-            img = img.cuda()
-            target = target.cuda()
-
-        with torch.no_grad():
-            output = model_ascent(img)
-            loss = criterion(output, target)
-
-        prec1, prec5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss.item(), img.size(0))
-        top1.update(prec1.item(), img.size(0))
-        top5.update(prec5.item(), img.size(0))
-
-    acc_bd = [top1.avg, top5.avg, losses.avg]
-
-    print('[Clean] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_clean[0], acc_clean[2]))
-    print('[Bad] Prec@1: {:.2f}, Loss: {:.4f}'.format(acc_bd[0], acc_bd[2]))
-
-    # save training progress
-    if epoch < opt.tuning_epochs + 1:
-        log_root = opt.log_root + '/ABL_results_tuning_epochs.csv'
-        test_process.append(
-            (epoch, acc_clean[0], acc_bd[0], acc_clean[2], acc_bd[2]))
-        df = pd.DataFrame(test_process, columns=("Epoch", "Test_clean_acc", "Test_bad_acc",
-                                                 "Test_clean_loss", "Test_bad_loss"))
-        df.to_csv(log_root, mode='a', index=False, encoding='utf-8')
-
-    return acc_clean, acc_bd
-
-
-def train(opt):
-    # Load models
-    print('----------- Network Initialization --------------')
-    model_ascent, _ = select_model(dataset=opt.dataset,
-                                   model_name=opt.model_name,
-                                   pretrained=False,
-                                   pretrained_models_path=opt.isolation_model_root,
-                                   n_classes=opt.num_class)
-    model_ascent.to(opt.device)
-    print('finished model init...')
-
-    # initialize optimizer
-    optimizer = torch.optim.SGD(model_ascent.parameters(),
-                                lr=opt.lr,
-                                momentum=opt.momentum,
-                                weight_decay=opt.weight_decay,
-                                nesterov=True)
-
-    # define loss functions
-    if opt.cuda:
-        criterion = nn.CrossEntropyLoss().cuda()
-    else:
-        criterion = nn.CrossEntropyLoss()
-
-    print('----------- Data Initialization --------------')
-    if opt.load_fixed_data:
-        tf_compose = transforms.Compose([
-            transforms.ToTensor()
-        ])
-        # load the fixed poisoned data, e.g. Dynamic, FC, DFST attacks etc.
-        poisoned_data = np.load(opt.poisoned_data_path, allow_pickle=True)
-        poisoned_data_loader = DataLoader(dataset=poisoned_data,
-                                          batch_size=opt.batch_size,
-                                          shuffle=True,
-                                          )
-    else:
-        poisoned_data, poisoned_data_loader = get_backdoor_loader(opt)
-
-    test_clean_loader, test_bad_loader = get_test_loader(opt)
-
-    print('----------- Train Initialization --------------')
-    for epoch in range(0, opt.tuning_epochs):
-
-        adjust_learning_rate(optimizer, epoch, opt)
-
-        # train every epoch
-        if epoch == 0:
-            # before training test firstly
-            test(opt, test_clean_loader, test_bad_loader, model_ascent,
-                 criterion, epoch + 1)
-
-        train_step(opt, poisoned_data_loader, model_ascent, optimizer, criterion, epoch + 1)
-
-        # evaluate on testing set
-        print('testing the ascended model......')
-        acc_clean, acc_bad = test(opt, test_clean_loader, test_bad_loader, model_ascent, criterion, epoch + 1)
-
-        if opt.save:
-            # remember best precision and save checkpoint
-            # is_best = acc_clean[0] > opt.threshold_clean
-            # opt.threshold_clean = min(acc_clean[0], opt.threshold_clean)
-            #
-            # best_clean_acc = acc_clean[0]
-            # best_bad_acc = acc_bad[0]
-            #
-            # save_checkpoint({
-            #     'epoch': epoch,
-            #     'state_dict': model_ascent.state_dict(),
-            #     'clean_acc': best_clean_acc,
-            #     'bad_acc': best_bad_acc,
-            #     'optimizer': optimizer.state_dict(),
-            # }, epoch, is_best, opt.checkpoint_root, opt.model_name)
-
-            # save checkpoint at interval epoch
-            if epoch % opt.interval == 0:
-                is_best = True
-                save_checkpoint({
-                    'epoch': epoch + 1,
-                    'state_dict': model_ascent.state_dict(),
-                    'clean_acc': acc_clean[0],
-                    'bad_acc': acc_bad[0],
-                    'optimizer': optimizer.state_dict(),
-                }, epoch, is_best, opt)
-
-    return poisoned_data, model_ascent
-
-
-def adjust_learning_rate(optimizer, epoch, opt):
-    if epoch < opt.tuning_epochs:
-        lr = opt.lr
-    else:
-        lr = 0.01
-    print('epoch: {} lr: {:.4f}'.format(epoch, lr))
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr
-
-
-def save_checkpoint(state, epoch, is_best, opt):
-    if is_best:
-        filepath = os.path.join(opt.save, opt.model_name + r'-tuning_epochs{}.tar'.format(epoch))
-        torch.save(state, filepath)
-        print('[info] Finish saving the model')
-
-def main():
-    print('----------- Train isolated model -----------')
-    opt = get_arguments().parse_args()
-    poisoned_data, ascent_model = train(opt)
-
-    print('----------- Calculate loss value per example -----------')
-    losses_idx = compute_loss_value(opt, poisoned_data, ascent_model)
-
-    print('----------- Collect isolation data -----------')
-    isolate_data(opt, poisoned_data, losses_idx)
-
-if (__name__ == '__main__'):
-    main()
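In the ABL procedure this file is based on, backdoor-poisoned examples fit unusually fast under the loss-guided ("ascent") tuning stage, so the lowest-loss examples are treated as the suspected-poison split: compute_loss_value sorts per-example losses in ascending order, and isolate_data keeps the lowest isolation_ratio fraction. For reference, a minimal, hypothetical driver for the new ABLAnalysis class is sketched below; the opt fields shown, the get_backdoor_loader helper, and the ascent-trained model_ascent are assumptions carried over from the original ABL repository, not definitions made in this commit.

# Hypothetical usage sketch; `get_backdoor_loader` and the ascent-trained
# `model_ascent` are assumed to come from the original ABL code base.
from argparse import Namespace

opt = Namespace(
    cuda=torch.cuda.is_available(),
    isolation_ratio=0.01,              # isolate the lowest-loss 1% of examples
    save=False,                        # skip writing the .npy splits in this sketch
    isolate_data_root='isolation_data',
    model_name='WRN-16-1',
)

poisoned_data, _ = get_backdoor_loader(opt)      # assumed ABL helper
analysis = ABLAnalysis(opt)
losses_idx = analysis.compute_loss_value(poisoned_data, model_ascent)
analysis.isolate_data(poisoned_data, losses_idx)

# min_max_normalization rescales any tensor to [0, 1], e.g. to turn raw
# losses into comparable scores (undefined when all values are equal):
scores = min_max_normalization(torch.tensor([0.2, 1.5, 3.0]))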
