Skip to content

Commit 3fc375d

Browse files
authored
Add files via upload
1 parent 0798714 commit 3fc375d

File tree

5 files changed

+887
-0
lines changed

5 files changed

+887
-0
lines changed

distillation_analysis.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Analyzes, visualizes knowledge distillation"""
2+
3+
import argparse
4+
import logging
5+
import os
6+
import numpy as np
7+
import torch
8+
import torch.nn.functional as F
9+
from torch.autograd import Variable
10+
import utils
11+
import model.net as net
12+
import model.resnet as resnet
13+
import model.data_loader as data_loader
14+
from torchnet.meter import ConfusionMeter
15+
from tqdm import tqdm
16+
17+
# Command-line interface for the analysis script.
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/base_model',
                    help="Directory of params.json")
parser.add_argument('--restore_file', default='best',
                    # single string (no backslash continuation) so the help
                    # text does not absorb source indentation whitespace
                    help="name of the file in --model_dir containing weights to load")
parser.add_argument('--dataset', default='dev',
                    help="dataset to analyze the model on")  # fixed typo "analze"
parser.add_argument('--temperature', type=float, default=1.0,
                    help="temperature used for softmax output")
24+
25+
26+
def model_analysis(model, dataloader, params, temperature=1., num_classes=10):
    """Run the model over a dataset and collect distillation-analysis statistics.

    Args:
        model: (torch.nn.Module) trained model, weights already loaded
        dataloader: (DataLoader) yields (data_batch, labels_batch) pairs
        params: (Params) hyperparameters; only `params.cuda` is read here
        temperature: (float) softmax temperature applied to the logits
        num_classes: (int) number of target classes for the confusion matrix

    Returns:
        softmax_scores: (np.ndarray) (num_examples, num_classes) tempered
            softmax probabilities
        predict_correct: (np.ndarray) (num_examples, 1) 1 where the argmax
            prediction matched the label, else 0
        confusion_matrix: (np.ndarray) (num_classes, num_classes) int counts
    """
    model.eval()
    confusion_matrix = ConfusionMeter(num_classes)
    softmax_scores = []
    predict_correct = []

    # Inference only: no_grad skips autograd bookkeeping (the original also
    # wrapped batches in the long-deprecated Variable, removed here).
    with torch.no_grad(), tqdm(total=len(dataloader)) as t:
        for idx, (data_batch, labels_batch) in enumerate(dataloader):

            if params.cuda:
                # `non_blocking` replaces the pre-0.4 `async` kwarg, which is
                # a SyntaxError on Python >= 3.7 (`async` became a keyword).
                data_batch = data_batch.cuda(non_blocking=True)
                labels_batch = labels_batch.cuda(non_blocking=True)

            output_batch = model(data_batch)

            confusion_matrix.add(output_batch.data, labels_batch.data)

            # Tempered softmax: divide logits by T before normalizing.
            softmax_scores_batch = F.softmax(output_batch / temperature, dim=1)
            softmax_scores.append(softmax_scores_batch.data.cpu().numpy())

            # Move to cpu / numpy for the correctness bookkeeping.
            output_np = output_batch.data.cpu().numpy()
            labels_np = labels_batch.data.cpu().numpy()

            predict_correct_batch = (np.argmax(output_np, axis=1) == labels_np).astype(int)
            predict_correct.append(np.reshape(predict_correct_batch, (labels_np.size, 1)))

            t.update()

    softmax_scores = np.vstack(softmax_scores)
    predict_correct = np.vstack(predict_correct)

    return softmax_scores, predict_correct, confusion_matrix.value().astype(int)
64+
65+
66+
if __name__ == '__main__':
    """
    Evaluate the model on a dataset split and save the analysis artifacts
    (softmax scores, per-example correctness, confusion matrix) to --model_dir.
    """
    # Load the parameters from the experiment's params.json
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    # use GPU if available
    params.cuda = torch.cuda.is_available()

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)

    # Get the logger
    utils.set_logger(os.path.join(args.model_dir, 'analysis.log'))

    # Create the input data pipeline for the requested split
    logging.info("Loading the dataset...")
    dataloader = data_loader.fetch_dataloader(args.dataset, params)
    logging.info("- done.")

    # Define the model graph
    model = resnet.ResNet18().cuda() if params.cuda else resnet.ResNet18()

    logging.info("Starting analysis...")

    # Reload weights from the saved checkpoint
    utils.load_checkpoint(os.path.join(args.model_dir, args.restore_file + '.pth.tar'), model)

    # Evaluate and analyze (temperature passed by keyword for clarity)
    softmax_scores, predict_correct, confusion_matrix = model_analysis(
        model, dataloader, params, temperature=args.temperature)

    # Persist each artifact as a text file next to the model
    results = {'softmax_scores': softmax_scores,
               'predict_correct': predict_correct,
               'confusion_matrix': confusion_matrix}
    for k, v in results.items():
        filename = args.dataset + '_temp' + str(args.temperature) + '_' + k + '.txt'
        save_path = os.path.join(args.model_dir, filename)
        np.savetxt(save_path, v)

requirements.txt

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
scipy==1.0.0
2+
numpy==1.14.0
3+
Pillow==8.1.1
4+
tabulate==0.8.2
5+
tensorflow==1.7.0rc0
6+
torch==0.3.0.post4
7+
torchvision==0.2.0
8+
tqdm==4.19.8
9+
torchnet

search_hyperparams.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""
2+
Perform hyperparameter search
3+
4+
A brief definition/clarification of 'params.json' files:
5+
6+
"model_version": "resnet18", # "base" models or "modelname"_distill models
7+
"subset_percent": 1.0, # use full (1.0) train set or partial (<1.0) train set
8+
"augmentation": "yes", # whether to use data augmentation in data_loader
9+
"teacher": "densenet", # no need to specify this for "base" cnn/resnet18
10+
"alpha": 0.0, # only used for experiments involving distillation
11+
"temperature": 1, # only used for experiments involving distillation
12+
"learning_rate": 1e-1, # as the name suggests
13+
"batch_size": 128, # for both train/eval
14+
"num_epochs": 200, # as the name suggests
15+
"dropout_rate": 0.5, # only valid for "cnn"-related models, not in resnet18
16+
"num_channels": 32, # only valid for "cnn"-related models, not in resnet18
17+
"save_summary_steps": 100,
18+
"num_workers": 4
19+
20+
"""
21+
22+
23+
import argparse
24+
import os
25+
from subprocess import check_call
26+
import sys
27+
import utils
28+
import logging
29+
30+
31+
PYTHON = sys.executable
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument('--parent_dir', default='experiments/learning_rate',
34+
help='Directory containing params.json')
35+
36+
def launch_training_job(parent_dir, job_name, params):
    """Launch training of the model with a set of hyperparameters in parent_dir/job_name.

    Args:
        parent_dir: (string) directory containing the reference params.json
        job_name: (string) unique name for this job; parent_dir/job_name is
            created (if needed) to hold the job's config, weights and logs
        params: (Params) hyperparameters to write to the job's params.json
    """
    # Create a new folder in parent_dir with unique name "job_name"
    model_dir = os.path.join(parent_dir, job_name)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Write parameters in json file
    json_path = os.path.join(model_dir, 'params.json')
    params.save(json_path)

    # Launch training with this config. An argument list with shell=False
    # (the default) passes model_dir intact even if it contains spaces or
    # shell metacharacters, unlike the interpolated shell string it replaces.
    cmd = [PYTHON, 'train.py', '--model_dir={}'.format(model_dir)]
    print(' '.join(cmd))
    check_call(cmd)
58+
59+
60+
if __name__ == "__main__":
    # Read the "reference" hyperparameters from the parent experiment dir.
    args = parser.parse_args()
    json_path = os.path.join(args.parent_dir, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    # Log the search into the parent directory.
    utils.set_logger(os.path.join(args.parent_dir, 'search_hyperparameters.log'))

    # Empirical grid over the two knowledge-distillation knobs: the loss
    # weight 'alpha' and the distilling 'temperature' (the teacher model is
    # whatever params.json in parent_dir specifies).
    alphas = [0.99, 0.95, 0.5, 0.1, 0.05]
    temperatures = [20., 10., 8., 6., 4.5, 3., 2., 1.5]

    logging.info("Searching hyperparameters...")
    logging.info("alphas: {}".format(alphas))
    logging.info("temperatures: {}".format(temperatures))

    for a in alphas:
        for t in temperatures:
            # Overwrite only the two searched fields; everything else is
            # inherited from the reference params.
            params.alpha = a
            params.temperature = t

            # Job names must be unique per (alpha, temperature) pair.
            job_name = "alpha_{}_Temp_{}".format(a, t)
            launch_training_job(args.parent_dir, job_name, params)

0 commit comments

Comments
 (0)