1
+ """Analyzes, visualizes knowledge distillation"""
2
+
3
+ import argparse
4
+ import logging
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import torch .nn .functional as F
9
+ from torch .autograd import Variable
10
+ import utils
11
+ import model .net as net
12
+ import model .resnet as resnet
13
+ import model .data_loader as data_loader
14
+ from torchnet .meter import ConfusionMeter
15
+ from tqdm import tqdm
16
+
17
+ parser = argparse .ArgumentParser ()
18
+ parser .add_argument ('--model_dir' , default = 'experiments/base_model' , help = "Directory of params.json" )
19
+ parser .add_argument ('--restore_file' , default = 'best' , help = "name of the file in --model_dir \
20
+ containing weights to load" )
21
+ parser .add_argument ('--dataset' , default = 'dev' , help = "dataset to analze the model on" )
22
+ parser .add_argument ('--temperature' , type = float , default = 1.0 , \
23
+ help = "temperature used for softmax output" )
24
+
25
+
26
+ def model_analysis (model , dataloader , params , temperature = 1. , num_classes = 10 ):
27
+ """
28
+ Generate Confusion Matrix on evaluation set
29
+ """
30
+ model .eval ()
31
+ confusion_matrix = ConfusionMeter (num_classes )
32
+ softmax_scores = []
33
+ predict_correct = []
34
+
35
+ with tqdm (total = len (dataloader )) as t :
36
+ for idx , (data_batch , labels_batch ) in enumerate (dataloader ):
37
+
38
+ if params .cuda :
39
+ data_batch , labels_batch = data_batch .cuda (async = True ), \
40
+ labels_batch .cuda (async = True )
41
+ data_batch , labels_batch = Variable (data_batch ), Variable (labels_batch )
42
+
43
+ output_batch = model (data_batch )
44
+
45
+ confusion_matrix .add (output_batch .data , labels_batch .data )
46
+
47
+ softmax_scores_batch = F .softmax (output_batch / temperature , dim = 1 )
48
+ softmax_scores_batch = softmax_scores_batch .data .cpu ().numpy ()
49
+ softmax_scores .append (softmax_scores_batch )
50
+
51
+ # extract data from torch Variable, move to cpu, convert to numpy arrays
52
+ output_batch = output_batch .data .cpu ().numpy ()
53
+ labels_batch = labels_batch .data .cpu ().numpy ()
54
+
55
+ predict_correct_batch = (np .argmax (output_batch , axis = 1 ) == labels_batch ).astype (int )
56
+ predict_correct .append (np .reshape (predict_correct_batch , (labels_batch .size , 1 )))
57
+
58
+ t .update ()
59
+
60
+ softmax_scores = np .vstack (softmax_scores )
61
+ predict_correct = np .vstack (predict_correct )
62
+
63
+ return softmax_scores , predict_correct , confusion_matrix .value ().astype (int )
64
+
65
+
66
+ if __name__ == '__main__' :
67
+ """
68
+ Evaluate the model on the test set.
69
+ """
70
+ # Load the parameters
71
+ args = parser .parse_args ()
72
+ json_path = os .path .join (args .model_dir , 'params.json' )
73
+ assert os .path .isfile (json_path ), "No json configuration file found at {}" .format (json_path )
74
+ params = utils .Params (json_path )
75
+
76
+ # use GPU if available
77
+ params .cuda = torch .cuda .is_available () # use GPU is available
78
+
79
+ # Set the random seed for reproducible experiments
80
+ torch .manual_seed (230 )
81
+ if params .cuda : torch .cuda .manual_seed (230 )
82
+
83
+ # Get the logger
84
+ utils .set_logger (os .path .join (args .model_dir , 'analysis.log' ))
85
+
86
+ # Create the input data pipeline
87
+ logging .info ("Loading the dataset..." )
88
+
89
+ # fetch dataloaders
90
+ # train_dl = data_loader.fetch_dataloader('train', params)
91
+ # dev_dl = data_loader.fetch_dataloader('dev', params)
92
+ dataloader = data_loader .fetch_dataloader (args .dataset , params )
93
+
94
+ logging .info ("- done." )
95
+
96
+ # Define the model graph
97
+ model = resnet .ResNet18 ().cuda () if params .cuda else resnet .ResNet18 ()
98
+
99
+ # fetch loss function and metrics
100
+ metrics = resnet .metrics
101
+
102
+ logging .info ("Starting analysis..." )
103
+
104
+ # Reload weights from the saved file
105
+ utils .load_checkpoint (os .path .join (args .model_dir , args .restore_file + '.pth.tar' ), model )
106
+
107
+ # Evaluate and analyze
108
+ softmax_scores , predict_correct , confusion_matrix = model_analysis (model , dataloader , params ,
109
+ args .temperature )
110
+
111
+ results = {'softmax_scores' : softmax_scores , 'predict_correct' : predict_correct ,
112
+ 'confusion_matrix' : confusion_matrix }
113
+
114
+ for k , v in results .items ():
115
+ filename = args .dataset + '_temp' + str (args .temperature ) + '_' + k + '.txt'
116
+ save_path = os .path .join (args .model_dir , filename )
117
+ np .savetxt (save_path , v )
0 commit comments