forked from vikram2000b/bad-teaching-unlearning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
75 lines (63 loc) · 2.93 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch
from sklearn.svm import SVC
from torch import nn
from torch.nn import functional as F
def JSDiv(p, q, reduction="mean"):
    """Return the Jensen-Shannon divergence between distributions ``p`` and ``q``.

    JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), where m = (p + q) / 2.

    ``F.kl_div(input, target)`` computes KL(target || exp(input)), so the log
    of the mixture goes in ``input`` and each original distribution goes in
    ``target``. Passing them the other way round yields
    0.5*KL(m||p) + 0.5*KL(m||q), which is not the JS divergence and becomes
    infinite wherever ``p`` or ``q`` has a zero entry the other does not.

    Args:
        p, q: tensors of probabilities with identical shapes (presumably
            (batch, classes) rows that each sum to 1 -- TODO confirm with callers).
        reduction: forwarded to ``F.kl_div``. The default ``"mean"`` averages
            over every element, matching the original scaling; ``"batchmean"``
            yields the per-sample mean divergence.

    Returns:
        The reduced divergence as a tensor.
    """
    m = (p + q) / 2
    # Clamp so log(m) stays finite where p and q are both zero; those entries
    # have target 0 and therefore contribute nothing to the divergence.
    log_m = torch.log(m.clamp(min=torch.finfo(m.dtype).tiny))
    kl_p = F.kl_div(log_m, p, reduction=reduction)
    kl_q = F.kl_div(log_m, q, reduction=reduction)
    return 0.5 * kl_p + 0.5 * kl_q
# ZRF/UnLearningScore
def UnLearningScore(tmodel, gold_model, forget_dl, batch_size, device):
    """Compute the ZRF / UnLearningScore of ``tmodel`` relative to ``gold_model``.

    Both models are evaluated on every batch of ``forget_dl``. Their softmax
    outputs are collected on the CPU and concatenated. The result is
    ``1 - JSDiv(tmodel_probs, gold_probs)``.

    Args:
        tmodel: model being evaluated.
        gold_model: reference model whose predictions ``tmodel`` is compared to.
        forget_dl: iterable of 3-element batches; only the first element
            (the input) is used.
        batch_size: accepted for interface compatibility; not used.
        device: device the inputs are moved to before the forward passes.
    """
    student_probs = []
    gold_probs = []
    with torch.no_grad():
        for inputs, _, _ in forget_dl:
            inputs = inputs.to(device)
            student_logits = tmodel(inputs)
            gold_logits = gold_model(inputs)
            student_probs.append(F.softmax(student_logits, dim=1).detach().cpu())
            gold_probs.append(F.softmax(gold_logits, dim=1).detach().cpu())
    all_student = torch.cat(student_probs, dim=0)
    all_gold = torch.cat(gold_probs, dim=0)
    return 1 - JSDiv(all_student, all_gold)
def entropy(p, dim=-1, keepdim=False):
    """Return the Shannon entropy (natural log) of ``p`` reduced over ``dim``.

    Entries that are not strictly positive contribute 0, so zero
    probabilities do not turn ``0 * log(0)`` into NaN.
    """
    zero = p.new_zeros(1)
    p_log_p = torch.where(p > 0, p * torch.log(p), zero)
    return -p_log_p.sum(dim=dim, keepdim=keepdim)
def collect_prob(data_loader, model, batch_size=1, num_workers=32, prefetch_factor=10):
    """Return ``model``'s softmax probabilities for every sample of ``data_loader.dataset``.

    A fresh, non-shuffled DataLoader is built over ``data_loader.dataset``.
    Output rows therefore follow dataset order regardless of how
    ``data_loader`` itself was configured. The defaults reproduce the
    previously hard-coded loader settings. Pass ``num_workers=0`` to load in
    the main process. A larger ``batch_size`` is much faster; this assumes the
    model has no batch-dependent behaviour (e.g. it is in eval mode) -- TODO
    confirm for callers.

    Args:
        data_loader: loader whose ``dataset`` yields 3-element samples; only
            the first element (the input) is fed to the model.
        model: model whose parameters determine the device inputs are moved to.
        batch_size: batch size of the internal loader.
        num_workers: worker processes of the internal loader.
        prefetch_factor: batches prefetched per worker (only used when
            ``num_workers > 0``).

    Returns:
        Softmax probabilities over the last dim, concatenated along dim 0 and
        left on the model's device.
    """
    loader_kwargs = {"batch_size": batch_size, "shuffle": False, "num_workers": num_workers}
    # DataLoader rejects prefetch_factor when loading in the main process.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor
    loader = torch.utils.data.DataLoader(data_loader.dataset, **loader_kwargs)

    # Loop-invariant: look the device up once rather than on every batch.
    device = next(model.parameters()).device
    prob = []
    with torch.no_grad():
        for batch in loader:
            data, _, _ = batch
            output = model(data.to(device))
            prob.append(F.softmax(output, dim=-1))
    return torch.cat(prob)
def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):
    """Build prediction-entropy features for the membership-inference attack.

    Returns:
        ``(X_f, Y_f, X_r, Y_r)`` where

        - ``X_r`` holds entropies of ``model``'s predictions on the retain set
          followed by the test set, reshaped to a ``(N, 1)`` numpy array.
        - ``Y_r`` labels those rows 1 (retain) and 0 (test).
        - ``X_f`` holds entropies on the forget set, shape ``(N_f, 1)``.
        - ``Y_f`` is all ones, one per forget sample.
    """
    retain_prob = collect_prob(retain_loader, model)
    forget_prob = collect_prob(forget_loader, model)
    test_prob = collect_prob(test_loader, model)

    X_r = torch.cat([entropy(retain_prob), entropy(test_prob)]).cpu().numpy().reshape(-1, 1)
    Y_r = np.concatenate([np.ones(len(retain_prob)), np.zeros(len(test_prob))])

    X_f = entropy(forget_prob).cpu().numpy().reshape(-1, 1)
    # A single array needs no concatenation.
    Y_f = np.ones(len(forget_prob))
    return X_f, Y_f, X_r, Y_r
def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):
    """Return the fraction of forget samples the attack classifier labels 1.

    An RBF-kernel SVC is fit on the retain (label 1) vs. test (label 0)
    entropy features from ``get_membership_attack_data``. It is then applied
    to the forget-set features, and the mean of its predictions is returned.
    """
    forget_features, _, train_features, train_labels = get_membership_attack_data(
        retain_loader, forget_loader, test_loader, model
    )
    attacker = SVC(C=3, gamma='auto', kernel='rbf')
    attacker.fit(train_features, train_labels)
    predictions = attacker.predict(forget_features)
    return predictions.mean()
@torch.no_grad()
def actv_dist(model1, model2, dataloader, device='cuda'):
    """Return the mean L2 distance between the softmax outputs of two models.

    For each sample, this computes the Euclidean distance between
    ``softmax(model1(x))`` and ``softmax(model2(x))`` over dim 1, then averages
    over all samples in ``dataloader``.

    Args:
        model1, model2: models to compare; both must accept inputs on ``device``.
        dataloader: iterable of 3-element batches; only the first element
            (the input) is used.
        device: device the inputs are moved to.

    Returns:
        A 0-dim CPU tensor holding the mean distance.
    """
    distances = []
    for x, _, _ in dataloader:
        x = x.to(device)
        model1_out = model1(x)
        model2_out = model2(x)
        delta = F.softmax(model1_out, dim=1) - F.softmax(model2_out, dim=1)
        diff = torch.sqrt(torch.sum(torch.square(delta), dim=1))
        distances.append(diff.cpu())
    return torch.cat(distances, dim=0).mean()