Skip to content

Commit 6eb6e74

Browse files
committed
FIX: Fine tuning of ImageNet models, adding checkpoint scope parameter.
1 parent 8de28bb commit 6eb6e74

File tree

5 files changed

+169
-12
lines changed

5 files changed

+169
-12
lines changed

COMMANDS.md

+20-9
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,27 @@ python eval_ssd_network.py \
8383
--batch_size=1 \
8484
--max_num_batches=10
8585

86-
87-
88-
python eval_image_classifier.py \
89-
--alsologtostderr \
90-
--checkpoint_path=${CHECKPOINT_PATH} \
86+
# =========================================================================== #
87+
# Fine tune VGG-based SSD network
88+
# =========================================================================== #
89+
DATASET_DIR=/media/paul/DataExt4/PascalVOC/dataset
90+
TRAIN_DIR=./logs/ssd_300_vgg_tmp
91+
CHECKPOINT_PATH=./checkpoints/vgg_16.ckpt
92+
python train_ssd_network.py \
93+
--train_dir=${TRAIN_DIR} \
9194
--dataset_dir=${DATASET_DIR} \
92-
--dataset_name=imagenet \
93-
--dataset_split_name=validation \
94-
--model_name=inception_v3
95-
95+
--dataset_name=pascalvoc_2007 \
96+
--dataset_split_name=train \
97+
--model_name=ssd_300_vgg \
98+
--checkpoint_path=${CHECKPOINT_PATH} \
99+
--checkpoint_model_scope=vgg_16 \
100+
--checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
101+
--save_summaries_secs=60 \
102+
--save_interval_secs=600 \
103+
--weight_decay=0.00001 \
104+
--optimizer=rmsprop \
105+
--learning_rate=0.0001 \
106+
--batch_size=32
96107

97108

98109
python train_ssd_network.py --train_dir=${TRAIN_DIR} --dataset_dir=${DATASET_DIR} --checkpoint_path=${CHECKPOINT_PATH} --checkpoint_exclude_scopes=ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box --dataset_name=kitti --dataset_split_name=train --model_name=ssd_300_vgg --save_summaries_secs=60 --save_interval_secs=60 --weight_decay=0.0005 --optimizer=adam --learning_rate=0.0001 --batch_size=8

README.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,11 @@ python train_ssd_network.py \
133133
--dataset_split_name=train \
134134
--model_name=ssd_300_vgg \
135135
--checkpoint_path=${CHECKPOINT_PATH} \
136-
--checkpoint_exclude_scopes=ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
136+
--checkpoint_model_scope=vgg_16 \
137+
--checkpoint_exclude_scopes=ssd_300_vgg/conv6,ssd_300_vgg/conv7,ssd_300_vgg/block8,ssd_300_vgg/block9,ssd_300_vgg/block10,ssd_300_vgg/block11,ssd_300_vgg/block4_box,ssd_300_vgg/block7_box,ssd_300_vgg/block8_box,ssd_300_vgg/block9_box,ssd_300_vgg/block10_box,ssd_300_vgg/block11_box \
137138
--save_summaries_secs=60 \
138139
--save_interval_secs=600 \
139-
--weight_decay=0.00001 \
140+
--weight_decay=0.0005 \
140141
--optimizer=rmsprop \
141142
--learning_rate=0.0001 \
142143
--batch_size=32

inspect_checkpoint.py

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""A simple script for inspect checkpoint files."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import argparse
21+
import sys
22+
23+
import numpy as np
24+
25+
from tensorflow.python import pywrap_tensorflow
26+
from tensorflow.python.platform import app
27+
from tensorflow.python.platform import flags
28+
29+
FLAGS = None
30+
31+
32+
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
33+
"""Prints tensors in a checkpoint file.
34+
35+
If no `tensor_name` is provided, prints the tensor names and shapes
36+
in the checkpoint file.
37+
38+
If `tensor_name` is provided, prints the content of the tensor.
39+
40+
Args:
41+
file_name: Name of the checkpoint file.
42+
tensor_name: Name of the tensor in the checkpoint file to print.
43+
all_tensors: Boolean indicating whether to print all tensors.
44+
"""
45+
try:
46+
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
47+
if all_tensors:
48+
var_to_shape_map = reader.get_variable_to_shape_map()
49+
for key in var_to_shape_map:
50+
print("tensor_name: ", key)
51+
print(reader.get_tensor(key))
52+
elif not tensor_name:
53+
print(reader.debug_string().decode("utf-8"))
54+
else:
55+
print("tensor_name: ", tensor_name)
56+
print(reader.get_tensor(tensor_name))
57+
except Exception as e: # pylint: disable=broad-except
58+
print(str(e))
59+
if "corrupted compressed block contents" in str(e):
60+
print("It's likely that your checkpoint file has been compressed "
61+
"with SNAPPY.")
62+
63+
64+
def parse_numpy_printoption(kv_str):
65+
"""Sets a single numpy printoption from a string of the form 'x=y'.
66+
67+
See documentation on numpy.set_printoptions() for details about what values
68+
x and y can take. x can be any option listed there other than 'formatter'.
69+
70+
Args:
71+
kv_str: A string of the form 'x=y', such as 'threshold=100000'
72+
73+
Raises:
74+
argparse.ArgumentTypeError: If the string couldn't be used to set any
75+
nump printoption.
76+
"""
77+
k_v_str = kv_str.split("=", 1)
78+
if len(k_v_str) != 2 or not k_v_str[0]:
79+
raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
80+
k, v_str = k_v_str
81+
printoptions = np.get_printoptions()
82+
if k not in printoptions:
83+
raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
84+
v_type = type(printoptions[k])
85+
if v_type is type(None):
86+
raise argparse.ArgumentTypeError(
87+
"Setting '%s' from the command line is not supported." % k)
88+
try:
89+
v = (v_type(v_str) if v_type is not bool
90+
else flags.BooleanParser().Parse(v_str))
91+
except ValueError as e:
92+
raise argparse.ArgumentTypeError(e.message)
93+
np.set_printoptions(**{k: v})
94+
95+
96+
def main(unused_argv):
97+
if not FLAGS.file_name:
98+
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
99+
"[--tensor_name=tensor_to_print]")
100+
sys.exit(1)
101+
else:
102+
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
103+
FLAGS.all_tensors)
104+
105+
106+
if __name__ == "__main__":
107+
parser = argparse.ArgumentParser()
108+
parser.register("type", "bool", lambda v: v.lower() == "true")
109+
parser.add_argument(
110+
"--file_name", type=str, default="", help="Checkpoint filename. "
111+
"Note, if using Checkpoint V2 format, file_name is the "
112+
"shared prefix between all files in the checkpoint.")
113+
parser.add_argument(
114+
"--tensor_name",
115+
type=str,
116+
default="",
117+
help="Name of the tensor to inspect")
118+
parser.add_argument(
119+
"--all_tensors",
120+
nargs="?",
121+
const=True,
122+
type="bool",
123+
default=False,
124+
help="If True, print the values of all the tensors.")
125+
parser.add_argument(
126+
"--printoptions",
127+
nargs="*",
128+
type=parse_numpy_printoption,
129+
help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
130+
FLAGS, unparsed = parser.parse_known_args()
131+
app.run(main=main, argv=[sys.argv[0]] + unparsed)

tf_utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ def add_variables_summaries(learning_rate):
179179
return summaries
180180

181181

182+
def update_model_scope(var, ckpt_scope, new_scope):
183+
return var.op.name.replace(new_scope,'vgg_16')
184+
185+
182186
def get_init_fn(flags):
183187
"""Returns a function run by the chief worker to warm-start the training.
184188
Note that the init_fn is only run when initializing the model during the very
@@ -211,6 +215,13 @@ def get_init_fn(flags):
211215
break
212216
if not excluded:
213217
variables_to_restore.append(var)
218+
# Change model scope if necessary.
219+
if flags.checkpoint_model_scope is not None:
220+
variables_to_restore = \
221+
{var.op.name.replace(flags.model_name,
222+
flags.checkpoint_model_scope): var
223+
for var in variables_to_restore}
224+
214225

215226
if tf.gfile.IsDirectory(flags.checkpoint_path):
216227
checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)

train_ssd_network.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@
141141
'evaluate the VGG and ResNet architectures which do not use a background '
142142
'class for the ImageNet dataset.')
143143
tf.app.flags.DEFINE_string(
144-
'model_name', 'inception_v3', 'The name of the architecture to train.')
144+
'model_name', 'ssd_300_vgg', 'The name of the architecture to train.')
145145
tf.app.flags.DEFINE_string(
146146
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
147147
'as `None`, then the model_name flag is used.')
@@ -158,6 +158,9 @@
158158
tf.app.flags.DEFINE_string(
159159
'checkpoint_path', None,
160160
'The path to a checkpoint from which to fine-tune.')
161+
tf.app.flags.DEFINE_string(
162+
'checkpoint_model_scope', None,
163+
'Model scope in the checkpoint. None if the same as the trained model.')
161164
tf.app.flags.DEFINE_string(
162165
'checkpoint_exclude_scopes', None,
163166
'Comma-separated list of scopes of variables to exclude when restoring '

0 commit comments

Comments
 (0)