
Commit 2147781

Robert Coleman authored: added batch normalization script & notebook
1 parent 7de87eb · commit 2147781

File tree

2 files changed: +467 -0 lines changed

cats_n_dogs_BN.ipynb (+362 lines)

@@ -0,0 +1,362 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Image Classification of Dogs vs. Cats Using CNN Ensemble"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Imports & environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using Theano backend.\n",
      "Using gpu device 0: GeForce GTX 980M (CNMeM is enabled with initial size: 90.0% of memory, cuDNN 5105)\n",
      "/home/robert/anaconda3/lib/python3.5/site-packages/theano/sandbox/cuda/__init__.py:600: UserWarning: Your cuDNN version is more recent than the one Theano officially supports. If you see any problems, try updating Theano or downgrading cuDNN to version 5.\n",
      " warnings.warn(warn)\n",
      "/home/robert/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n",
      " warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')\n",
      "/home/robert/anaconda3/lib/python3.5/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.\n",
      " warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "from glob import glob\n",
    "from shutil import copyfile\n",
    "from vgg_bn import Vgg16BN\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "\n",
    "ROOT_DIR = os.getcwd()\n",
    "DATA_HOME_DIR = ROOT_DIR + '/data'\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Config & Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# paths\n",
    "data_path = DATA_HOME_DIR + '/'\n",
    "train_path = data_path + '/train/'\n",
    "valid_path = data_path + '/valid/'\n",
    "test_path = DATA_HOME_DIR + '/test/'\n",
    "model_path = ROOT_DIR + '/models/'\n",
    "submission_path = ROOT_DIR + '/submissions/'\n",
    "\n",
    "# data\n",
    "img_width, img_height = 224, 224\n",
    "batch_size = 64\n",
    "nb_train_samples = 23000\n",
    "nb_valid_samples = 2000\n",
    "nb_test_samples = 12500\n",
    "classes = [\"cats\", \"dogs\"]\n",
    "n_classes = len(classes)\n",
    "\n",
    "# model\n",
    "nb_epoch = 10\n",
    "nb_aug = 5\n",
    "lr = 0.001"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build the VGG model w/ Batch Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "____________________________________________________________________________________________________\n",
      "Layer (type) Output Shape Param # Connected to \n",
      "====================================================================================================\n",
      "lambda_1 (Lambda) (None, 3, 224, 224) 0 lambda_input_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_1 (ZeroPadding2D) (None, 3, 226, 226) 0 lambda_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_1 (Convolution2D) (None, 64, 224, 224) 0 zeropadding2d_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_2 (ZeroPadding2D) (None, 64, 226, 226) 0 convolution2d_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_2 (Convolution2D) (None, 64, 224, 224) 0 zeropadding2d_2[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "maxpooling2d_1 (MaxPooling2D) (None, 64, 112, 112) 0 convolution2d_2[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_3 (ZeroPadding2D) (None, 64, 114, 114) 0 maxpooling2d_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_3 (Convolution2D) (None, 128, 112, 112) 0 zeropadding2d_3[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_4 (ZeroPadding2D) (None, 128, 114, 114) 0 convolution2d_3[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_4 (Convolution2D) (None, 128, 112, 112) 0 zeropadding2d_4[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "maxpooling2d_2 (MaxPooling2D) (None, 128, 56, 56) 0 convolution2d_4[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_5 (ZeroPadding2D) (None, 128, 58, 58) 0 maxpooling2d_2[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_5 (Convolution2D) (None, 256, 56, 56) 0 zeropadding2d_5[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_6 (ZeroPadding2D) (None, 256, 58, 58) 0 convolution2d_5[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_6 (Convolution2D) (None, 256, 56, 56) 0 zeropadding2d_6[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_7 (ZeroPadding2D) (None, 256, 58, 58) 0 convolution2d_6[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_7 (Convolution2D) (None, 256, 56, 56) 0 zeropadding2d_7[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "maxpooling2d_3 (MaxPooling2D) (None, 256, 28, 28) 0 convolution2d_7[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_8 (ZeroPadding2D) (None, 256, 30, 30) 0 maxpooling2d_3[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_8 (Convolution2D) (None, 512, 28, 28) 0 zeropadding2d_8[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_9 (ZeroPadding2D) (None, 512, 30, 30) 0 convolution2d_8[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_9 (Convolution2D) (None, 512, 28, 28) 0 zeropadding2d_9[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_10 (ZeroPadding2D) (None, 512, 30, 30) 0 convolution2d_9[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_10 (Convolution2D) (None, 512, 28, 28) 0 zeropadding2d_10[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "maxpooling2d_4 (MaxPooling2D) (None, 512, 14, 14) 0 convolution2d_10[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_11 (ZeroPadding2D) (None, 512, 16, 16) 0 maxpooling2d_4[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_11 (Convolution2D) (None, 512, 14, 14) 0 zeropadding2d_11[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_12 (ZeroPadding2D) (None, 512, 16, 16) 0 convolution2d_11[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_12 (Convolution2D) (None, 512, 14, 14) 0 zeropadding2d_12[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "zeropadding2d_13 (ZeroPadding2D) (None, 512, 16, 16) 0 convolution2d_12[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "convolution2d_13 (Convolution2D) (None, 512, 14, 14) 0 zeropadding2d_13[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "maxpooling2d_5 (MaxPooling2D) (None, 512, 7, 7) 0 convolution2d_13[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "flatten_1 (Flatten) (None, 25088) 0 maxpooling2d_5[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "dense_1 (Dense) (None, 4096) 0 flatten_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "batchnormalization_1 (BatchNormal(None, 4096) 0 dense_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "dropout_1 (Dropout) (None, 4096) 0 batchnormalization_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "dense_2 (Dense) (None, 4096) 0 dropout_1[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "batchnormalization_2 (BatchNormal(None, 4096) 0 dense_2[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "dropout_2 (Dropout) (None, 4096) 0 batchnormalization_2[0][0] \n",
      "____________________________________________________________________________________________________\n",
      "dense_4 (Dense) (None, 2) 8194 dropout_2[0][0] \n",
      "====================================================================================================\n",
      "Total params: 8194\n",
      "____________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "vgg = Vgg16BN(size=(img_width, img_height), n_classes=n_classes, batch_size=batch_size, lr=lr)\n",
    "model = vgg.model\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "info_string = \"{0}x{1}_{2}epoch_{3}aug_{4}lr_vgg16-bn\".format(img_width, img_height, nb_epoch, nb_aug, lr)\n",
    "ckpt_fn = model_path + '{val_loss:.2f}-loss_' + info_string + '.h5'\n",
    "\n",
    "ckpt = ModelCheckpoint(filepath=ckpt_fn,\n",
    "                       monitor='val_loss',\n",
    "                       save_best_only=True,\n",
    "                       save_weights_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "vgg.fit(train_path, valid_path,\n",
    "        nb_trn_samples=nb_train_samples,\n",
    "        nb_val_samples=nb_valid_samples,\n",
    "        nb_epoch=nb_epoch,\n",
    "        callbacks=[ckpt],\n",
    "        aug=nb_aug)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict on Test Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating predictions for Augmentation... 0\n",
      "Found 12500 images belonging to 1 classes.\n",
      "Generating predictions for Augmentation... 1\n",
      "Found 12500 images belonging to 1 classes.\n",
      "Generating predictions for Augmentation... 2\n",
      "Found 12500 images belonging to 1 classes.\n",
      "Generating predictions for Augmentation... 3\n",
      "Found 12500 images belonging to 1 classes.\n",
      "Generating predictions for Augmentation... 4\n",
      "Found 12500 images belonging to 1 classes.\n",
      "Averaging Predictions Across Augmentations...\n"
     ]
    }
   ],
"source": [
278+
"# generate predictions\n",
279+
"for aug in range(nb_aug):\n",
280+
" print(\"Generating predictions for Augmentation {0}...\",format(aug+1))\n",
281+
" if aug == 0:\n",
282+
" predictions, filenames = vgg.test(test_path, nb_test_samples, aug=nb_aug)\n",
283+
" else:\n",
284+
" aug_pred, filenames = vgg.test(test_path, nb_test_samples, aug=nb_aug)\n",
285+
" predictions += aug_pred\n",
286+
"\n",
287+
"print(\"Averaging Predictions Across Augmentations...\")\n",
288+
"predictions /= nb_aug"
289+
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# clip predictions\n",
    "c = 0.01\n",
    "preds = np.clip(predictions, c, 1-c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing Predictions to CSV...\n",
      "0 / 12500\n",
      "2500 / 12500\n",
      "5000 / 12500\n",
      "7500 / 12500\n",
      "10000 / 12500\n",
      "Done.\n"
     ]
    }
   ],
"source": [
326+
"sub_file = submission_path + info_string + '.csv'\n",
327+
"\n",
328+
"with open(sub_file, 'w') as f:\n",
329+
" print(\"Writing Predictions to CSV...\")\n",
330+
" f.write('id,label\\n')\n",
331+
" for i, image_name in enumerate(filenames):\n",
332+
" pred = ['%.6f' % p for p in preds[i, :]]\n",
333+
" if i % 2500 == 0:\n",
334+
" print(i, '/', nb_test_samples)\n",
335+
" f.write('%s,%s\\n' % (os.path.basename(image_name).replace('.jpg', ''), (pred[1])))\n",
336+
" print(\"Done.\")"
337+
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
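
A note on the prediction cell: it runs `nb_aug` augmented passes over the test set with `vgg.test` and averages the resulting class probabilities (test-time augmentation). A minimal standalone sketch of that averaging pattern, assuming a hypothetical `predict_fn()` that returns an `(n_samples, n_classes)` probability array for one augmented pass (the notebook uses `vgg.test` from the accompanying `vgg_bn.py` for this):

```python
import numpy as np

def tta_average(predict_fn, nb_aug=5):
    """Average class probabilities over nb_aug augmented prediction passes.

    predict_fn is assumed to return an (n_samples, n_classes) array of
    probabilities for one randomly augmented pass over the test set.
    """
    total = None
    for _ in range(nb_aug):
        probs = np.asarray(predict_fn())  # one augmented pass
        total = probs if total is None else total + probs
    return total / nb_aug  # element-wise mean across passes
```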
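The `np.clip(predictions, c, 1-c)` step with `c = 0.01` bounds the log-loss penalty for confidently wrong predictions: a dog image scored with a dog probability near 0 would contribute a -log(p) term that grows without bound, whereas after clipping the worst case is -log(0.01) ≈ 4.6. A small illustration with a made-up probability:

```python
import numpy as np

def binary_log_loss(y_true, p):
    """Per-image log loss for a binary prediction p = P(class 1)."""
    return -(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

raw = 1e-6                          # hypothetical, confidently wrong P(dog) for a dog image
clipped = np.clip(raw, 0.01, 0.99)  # same clipping as the notebook (c = 0.01)

print(binary_log_loss(1, raw))      # ~13.82
print(binary_log_loss(1, clipped))  # ~4.61, the worst case after clipping
```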
