-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
93 lines (72 loc) · 3.07 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import json
import torch.utils.data.sampler
import os
import glob
import random
import time
import configs
import backbone
import data.feature_loader as feat_loader
from data.datamgr import SetDataManager
from methods.baselinetrain import BaselineTrain
from methods.baselinefinetune import BaselineFinetune
from methods.SSL_train import SSL_Train
from methods.SSL_finetune import SSL_Finetune
from io_utils import model_dict, parse_args, get_resume_file, get_best_file , get_assigned_file
from tqdm import *
# Pin this process to the CUDA device with index 1.
# NOTE(review): this assignment unconditionally overwrites any
# CUDA_VISIBLE_DEVICES value already present in the environment, and it runs
# after `import torch`; presumably it still takes effect because CUDA has not
# been initialised at this point -- confirm before relying on it.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def feature_evaluation(cl_data_file, model, n_way = 5, n_support = 5, n_query = 15, adaptation = False):
    """Build one random few-shot episode from cached features and score it.

    ``n_way`` classes are drawn at random from ``cl_data_file``; for each of
    them ``n_support + n_query`` features are drawn without replacement. The
    stacked episode is handed to ``model`` and the predicted class of every
    query is compared with its episode-local label.

    Args:
        cl_data_file: Mapping from class key to an indexable collection of
            per-image features (anything supporting ``len()`` and integer
            indexing whose items ``np.squeeze`` accepts).
        model: Object exposing ``set_forward(z, is_feature=True)`` and
            ``set_forward_adaptation(z, is_feature=True)``. Its ``n_query``
            attribute is overwritten with ``n_query`` before the call. The
            returned scores are presumably a tensor with one row per query
            and one column per class -- verify against the model classes.
        n_way: Number of classes in the episode.
        n_support: Number of support features per class.
        n_query: Number of query features per class.
        adaptation: When true, ``set_forward_adaptation`` is used instead of
            ``set_forward``.

    Returns:
        Episode accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``n_way`` exceeds the number of available classes
            (raised by ``random.sample``).
        IndexError: If a selected class holds fewer than
            ``n_support + n_query`` features.
    """
    # random.sample() requires a real sequence; handing it a dict view is
    # deprecated since Python 3.9 and raises TypeError from 3.11 onwards.
    class_list = list(cl_data_file.keys())
    select_class = random.sample(class_list, n_way)

    n_per_class = n_support + n_query
    z_all = []
    for cl in select_class:
        img_feat = cl_data_file[cl]
        perm_ids = np.random.permutation(len(img_feat)).tolist()
        # The first n_per_class shuffled entries form this class's slice of
        # the episode (support and query features together).
        z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_per_class)])
    z_all = torch.from_numpy(np.array(z_all))

    # The model derives the support/query split from its own n_query.
    model.n_query = n_query
    if adaptation:
        scores = model.set_forward_adaptation(z_all, is_feature = True)
    else:
        scores = model.set_forward(z_all, is_feature = True)

    pred = scores.detach().cpu().numpy().argmax(axis = 1)
    # Queries are laid out class by class, so labels are 0..n_way-1, each
    # repeated n_query times.
    y = np.repeat(range(n_way), n_query)
    acc = np.mean(pred == y) * 100
    return acc
if __name__ == '__main__':
    # Script entry point: evaluate a fine-tuning based few-shot method on
    # features that were saved to an .hdf5 file beforehand, and report the
    # mean accuracy over a fixed number of random episodes together with a
    # 1.96-sigma confidence half-width.
    params = parse_args('test')
    iter_num = 600
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    # Method name -> (model class, extra constructor keyword arguments).
    method_table = {
        'baseline': (BaselineFinetune, {}),
        'baseline++': (BaselineFinetune, {'loss_type': 'dist'}),
        'SSL': (SSL_Finetune, {}),
    }
    if params.method not in method_table:
        raise ValueError('Unknown method')
    model_cls, extra_kwargs = method_table[params.method]
    model = model_cls(model_dict[params.model], **extra_kwargs, **few_shot_params).cuda()

    # The feature file sits in a "features" tree that mirrors the
    # "checkpoints" tree, one directory per dataset/backbone/method combo.
    aug_suffix = '_aug' if params.train_aug else ''
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method,
    ) + aug_suffix

    # The split defaults to novel classes, but base or val can be tested as
    # well; a specific save iteration is appended to the file stem.
    split = params.split
    split_str = split if params.save_iter == -1 else split + "_" + str(params.save_iter)

    feature_dir = checkpoint_dir.replace("checkpoints", "features")
    novel_file = './' + os.path.join(feature_dir, split_str + ".hdf5")
    cl_data_file = feat_loader.init_loader(novel_file)

    # One accuracy value per randomly sampled episode.
    acc_all = np.asarray([
        feature_evaluation(cl_data_file, model, n_query=15,
                           adaptation=params.adaptation, **few_shot_params)
        for _ in tqdm(range(iter_num))
    ])

    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    half_width = 1.96 * acc_std / np.sqrt(iter_num)
    print('%d Test Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, half_width))