-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrainstyle.py
executable file
·133 lines (108 loc) · 6.04 KB
/
trainstyle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python
# coding: utf-8
from __future__ import print_function
import os
import numpy as np
from scipy.misc import imread
from optimize import optimize
from argparse import ArgumentParser, RawTextHelpFormatter
import quickpaint
import glob
# Module-level device settings: the device string and, presumably, the
# fraction of GPU memory to use (inferred from the name only).
# NOTE(review): neither name is referenced elsewhere in this script's visible
# code; presumably kept for other modules or legacy use — verify before removing.
DEVICE, FRAC_GPU = '/gpu:0', 1
def get_opts():
    """Parse and validate the command-line options for style training.

    Besides parsing, this checks that the style image, training-image folder
    and network file exist, rejects non-positive / negative numeric settings,
    and creates the output-test and checkpoint directories when missing.

    Returns:
        argparse.Namespace holding the parsed options.

    Invalid input is reported through ``parser.error``, which prints the usage
    and exits (``SystemExit`` with status 2).
    """
    parser = ArgumentParser(
        description="Train neural network on the COCO data set using a specific style.",
        formatter_class=RawTextHelpFormatter,
        usage="./trainstyle.py -s [ style ] -c [ checkpoint dir ] -o [ output test image ] "
              "-i [ checkpoint iterations ] -od [ output test dir ] -cw [ content weight ] "
              "-tw [ tv weight ] -sw [ style weight ] -b [ batch size ] -l [ learning rate ] "
              "-n [ network path ]\n"
              "Example: ./trainstyle.py -s styles/the_scream.jpg -c checkpoint -o stanford.jpg "
              "-od test -cw 1.5e1 -i 1000 -b 20")
    # required=True means the 'checkpoint' default is never actually used.
    parser.add_argument('-c', '--checkpoint-dir', type=str, default='checkpoint',
                        dest='checkpoint_dir', help='dir to save checkpoint in',
                        metavar='CHECKPOINT_DIR', required=True)
    parser.add_argument('-s', '--style', type=str,
                        dest='style', help='desired style image path',
                        metavar='STYLE', required=True)
    parser.add_argument('-t', '--train-path', type=str, default='data/train2014',
                        dest='train_path', help='path to training images folder',
                        metavar='TRAIN_PATH')
    parser.add_argument('-o', '--output', type=str,
                        dest='output', help='output test image at every checkpoint path',
                        metavar='OUTPUT', default=False)
    parser.add_argument('-od', '--output-dir', type=str,
                        dest='output_dir', help='output test images dir',
                        metavar='OUTPUT DIR', default=False)
    parser.add_argument('-e', '--epochs', type=int, default=2,
                        dest='epochs', help='# of epochs', metavar='EPOCHS')
    parser.add_argument('-b', '--batch-size', type=int, default=4,
                        dest='batch_size', help='batch size',
                        metavar='BATCH_SIZE')
    parser.add_argument('-i', '--checkpoint-iterations', type=int, default=2000,
                        dest='checkpoint_iterations', help='checkpoint frequency',
                        metavar='CHECKPOINT_ITERATIONS')
    parser.add_argument('-n', '--net-path', type=str, default='data/imagenet-vgg-verydeep-19.mat',
                        dest='net_path', help='path to VGG19 network (default %(default)s)',
                        metavar='NET_PATH')
    parser.add_argument('-cw', '--content-weight', type=float, default=7.5e0,
                        dest='content_weight', help='content weight (default %(default)s)',
                        metavar='CONTENT_WEIGHT')
    parser.add_argument('-sw', '--style-weight', type=float, default=1e2,
                        dest='style_weight', help='style weight (default %(default)s)',
                        metavar='STYLE_WEIGHT')
    parser.add_argument('-tw', '--tv-weight', type=float, default=2e2,
                        dest='tv_weight', help='total variation regularization weight (default %(default)s)',
                        metavar='TV_WEIGHT')
    parser.add_argument('-l', '--learning-rate', type=float, default=1e-3,
                        dest='learning_rate', help='learning rate (default %(default)s)',
                        metavar='LEARNING_RATE')
    opts = parser.parse_args()

    # Validate with parser.error instead of assert: asserts are stripped under
    # `python -O`, and parser.error gives the user usage plus a clear message.
    for name in ('epochs', 'batch_size', 'checkpoint_iterations'):
        value = getattr(opts, name)
        if value <= 0:
            parser.error('%s must be positive (got %s)' % (name, value))
    for name in ('content_weight', 'style_weight', 'tv_weight', 'learning_rate'):
        value = getattr(opts, name)
        if value < 0:
            parser.error('%s must be non-negative (got %s)' % (name, value))

    if not os.path.exists(opts.style):
        parser.error('style image not found.. %s does not exist!' % opts.style)
    if not os.path.exists(opts.train_path):
        parser.error('train path not found.. %s does not exist!' % opts.train_path)
    if not os.path.exists(opts.net_path):
        parser.error('Network not found.. %s does not exist!' % opts.net_path)

    # -o / -od are optional and default to False; only touch the filesystem
    # when they were given (a bool passed to os.path would be treated as fd 0).
    if opts.output:
        if not os.path.exists(opts.output):
            parser.error('Output test not found.. %s does not exist' % opts.output)
        if not opts.output_dir:
            # Test renders are written into output_dir, so it must be set too.
            parser.error('--output-dir is required when --output is given')
    if opts.output_dir and not os.path.exists(opts.output_dir):
        print('creating output tests dir')
        os.makedirs(opts.output_dir)

    # Create the checkpoint dir when missing; fail early if the path is a file.
    if not os.path.exists(opts.checkpoint_dir):
        print('creating checkpoints dir')
        os.makedirs(opts.checkpoint_dir)
    elif not os.path.isdir(opts.checkpoint_dir):
        parser.error('checkpoint dir %s is not a directory' % opts.checkpoint_dir)
    return opts
def read_img(src):
    """Load the image at *src* as an RGB array using scipy's ``imread``.

    If the decoded array is not already height x width x 3, it is stacked
    three times along a new last axis with ``np.dstack`` (e.g. a 2-D
    grayscale array becomes three identical channels).

    NOTE(review): ``scipy.misc.imread`` was removed in SciPy 1.2, so this
    requires an older SciPy (with Pillow installed) — consider migrating.
    """
    img = imread(src, mode='RGB')
    already_rgb = img.ndim == 3 and img.shape[-1] == 3
    if already_rgb:
        return img
    return np.dstack([img] * 3)
def main():
    """Script entry point: train a network on a style image and report progress.

    Parses the CLI options and loads the style image. It then collects the
    (non-hidden) entries directly inside the training folder and iterates over
    the progress tuples yielded by ``optimize``. At each step it prints the
    losses; when ``--output`` is set, it also renders a test image via
    ``quickpaint.eval_mul_dims`` into ``--output-dir``.
    """
    opts = get_opts()
    style_img = read_img(opts.style)
    train_files = glob.glob(opts.train_path + '/*')
    base_name = os.path.splitext(os.path.basename(opts.style))[0]
    ckpt_file = os.path.join(opts.checkpoint_dir, '%s.ckpt' % base_name)

    optimize_kwargs = dict(
        epochs=opts.epochs,
        print_iterations=opts.checkpoint_iterations,
        batch_size=opts.batch_size,
        save_path=ckpt_file,
        learning_rate=opts.learning_rate,
    )
    # Presumably optimize() is a generator yielding (preds, losses, iteration,
    # epoch) at each print interval — verify against optimize.py.
    progress = optimize(train_files, style_img, opts.content_weight,
                        opts.style_weight, opts.tv_weight, opts.net_path,
                        **optimize_kwargs)
    for _preds, losses, step, epoch in progress:
        style_loss, content_loss, tv_loss, total_loss = losses
        print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, step, total_loss))
        print('style: %s, content:%s, tv: %s' % (style_loss, content_loss, tv_loss))
        if not opts.output:
            continue
        preds_path = '%s/%s_%s.png' % (opts.output_dir, epoch, step)
        quickpaint.eval_mul_dims(opts.output, preds_path, opts.checkpoint_dir)

    cmd_text = 'python quickpaint.py --checkpoint %s ...' % opts.checkpoint_dir
    print("Training complete. For evaluation:\n `%s`" % cmd_text)
if __name__ == '__main__':
    # Start training only when run as a script, never on import.
    main()