-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathoptimize.py
executable file
·134 lines (109 loc) · 5.47 KB
/
optimize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import print_function
import functools
import vgg, time
import tensorflow as tf, numpy as np
import transform
from operator import mul
from scipy.misc import imread
# Names of the VGG activations (keys of the dict returned by ``vgg.net``)
# whose Gram matrices are matched in the style loss.
STYLE_LAYERS = (
    'relu1_1',
    'relu2_1',
    'relu3_1',
    'relu4_1',
    'relu5_1',
)
# VGG activation compared element-wise in the content loss.
CONTENT_LAYER = 'relu4_2'
# Name of the CUDA device-visibility environment variable.
# NOTE(review): not referenced in this module; presumably used by callers.
DEVICES = 'CUDA_VISIBLE_DEVICES'
def _tensor_size(tensor):
return functools.reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
# read image using scipy
def _resize_nearest(img, height, width):
    """Return *img* resampled to (height, width) by nearest-neighbour lookup.

    Only the first two axes are resampled; a trailing channel axis is kept.
    The input is returned unchanged when it already has the target size.
    """
    src_h, src_w = img.shape[0], img.shape[1]
    if (src_h, src_w) == (height, width):
        return img
    # Integer source row/column for every destination row/column.
    rows = np.arange(height) * src_h // height
    cols = np.arange(width) * src_w // width
    return img[rows[:, np.newaxis], cols]


def read_img(src, img_size=None):
    """Read the image at *src* as a 3-channel array, optionally resized.

    Args:
        src: path (or file object) passed to ``imread``.
        img_size: optional ``(height, width[, channels])``.  When given, the
            image is resized to ``height x width`` with nearest-neighbour
            sampling; any third entry is ignored (the result always has 3
            channels).  ``optimize`` passes ``(256, 256, 3)`` here.

    Returns:
        numpy array of shape (H, W, 3), with the dtype produced by ``imread``.

    NOTE(review): ``scipy.misc.imread`` was removed in SciPy 1.2, so this
    module needs an older SciPy (with Pillow) or a replacement reader.
    """
    img = imread(src, mode='RGB')
    # mode='RGB' should already give (H, W, 3); normalise anything else
    # without multiplying channel counts the way a blind dstack would.
    if img.ndim == 2:
        img = np.dstack((img, img, img))
    elif img.shape[2] in (1, 2):
        # Grayscale (optionally with alpha): replicate the intensity channel.
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif img.shape[2] > 3:
        # Drop alpha / extra channels.
        img = img[:, :, :3]
    if img_size:
        img = _resize_nearest(img, img_size[0], img_size[1])
    return img
def _style_grams(style_target, vgg_path):
    """Return {layer: Gram matrix} of *style_target* for each STYLE_LAYERS entry."""
    style_shape = (1,) + style_target.shape
    print(style_shape)
    grams = {}
    # Throw-away graph on the CPU: the results are fixed numpy targets, so
    # none of these ops belong in the training graph.
    with tf.Graph().as_default(), tf.device('/cpu:0'), tf.Session():
        style_image = tf.placeholder(tf.float32, shape=style_shape, name='style_image')
        style_image_pre = vgg.preprocess(style_image)
        net = vgg.net(vgg_path, style_image_pre)
        style_pre = np.array([style_target])
        for layer in STYLE_LAYERS:
            features = net[layer].eval(feed_dict={style_image: style_pre})
            # Flatten to (positions, channels) before forming F^T F.
            features = np.reshape(features, (-1, features.shape[3]))
            grams[layer] = np.matmul(features.T, features) / features.size
    return grams


def _load_batch(paths, batch_shape):
    """Read *paths* into a zero-padded float32 array of shape *batch_shape*."""
    batch = np.zeros(batch_shape, dtype=np.float32)
    for j, img_p in enumerate(paths):
        # Resize to the spatial size the placeholder expects.
        batch[j] = read_img(img_p, batch_shape[1:]).astype(np.float32)
    return batch


def optimize(content_targets, style_target, content_weight, style_weight,
             tv_weight, vgg_path, epochs=2, print_iterations=1000,
             batch_size=4, save_path='saver/fns.ckpt',
             learning_rate=1e-3, debug=False):
    """Train the transform network with Adam, yielding progress periodically.

    This is a generator.  It builds a TF1 graph in which ``transform.net``
    stylizes a batch of content images and the loss is a weighted sum of a
    VGG content loss (``CONTENT_LAYER``), a Gram-matrix style loss
    (``STYLE_LAYERS``) and a total-variation term.  Before each yield, the
    session is checkpointed to *save_path*.

    Args:
        content_targets: sequence of image paths used as training content;
            trimmed to a multiple of *batch_size*.
        style_target: style image array.  Assumed (H, W, C), since a batch
            axis is prepended for the VGG placeholder; TODO confirm.
        content_weight, style_weight, tv_weight: loss-term weights.
        vgg_path: forwarded to ``vgg.net``.
        epochs: number of passes over *content_targets*.
        print_iterations: yield and checkpoint every this many batches;
            0 disables the periodic yield (the final one still happens).
        batch_size: images per training step.
        save_path: checkpoint path for ``tf.train.Saver.save``; its parent
            directory is created if missing.
        learning_rate: Adam learning rate.
        debug: when True, print the wall time of every batch.

    Yields:
        ``(preds, (style_loss, content_loss, tv_loss, total_loss), iterations,
        epoch)``.  *preds* is the network output for the current batch, and
        *iterations* is the number of batches processed so far in *epoch*.
    """
    import os
    import random

    mod = len(content_targets) % batch_size
    if mod > 0:
        print("Train set has been trimmed slightly..")
        content_targets = content_targets[:-mod]

    batch_shape = (batch_size, 256, 256, 3)
    style_features = _style_grams(style_target, vgg_path)

    with tf.Graph().as_default(), tf.Session() as sess:
        X_content = tf.placeholder(tf.float32, shape=batch_shape, name="X_content")
        X_pre = vgg.preprocess(X_content)

        # Content target: VGG activations of the unstylized batch.
        content_net = vgg.net(vgg_path, X_pre)
        content_target = content_net[CONTENT_LAYER]

        # Stylized output (input scaled by 1/255) and its VGG activations.
        preds = transform.net(X_content / 255.0)
        preds_pre = vgg.preprocess(preds)
        net = vgg.net(vgg_path, preds_pre)

        content_size = _tensor_size(content_target) * batch_size
        assert _tensor_size(content_target) == _tensor_size(net[CONTENT_LAYER])
        content_loss = content_weight * (
            2 * tf.nn.l2_loss(net[CONTENT_LAYER] - content_target) / content_size)

        style_losses = []
        for style_layer in STYLE_LAYERS:
            layer = net[style_layer]
            bs, height, width, filters = [d.value for d in layer.get_shape()]
            size = height * width * filters
            feats = tf.reshape(layer, (bs, height * width, filters))
            feats_T = tf.transpose(feats, perm=[0, 2, 1])
            grams = tf.matmul(feats_T, feats) / size
            style_gram = style_features[style_layer]
            style_losses.append(2 * tf.nn.l2_loss(grams - style_gram) / style_gram.size)
        style_loss = style_weight * functools.reduce(tf.add, style_losses) / batch_size

        # Total-variation de-noising: penalise differences between
        # vertically and horizontally adjacent output pixels.
        tv_y_size = _tensor_size(preds[:, 1:, :, :])
        tv_x_size = _tensor_size(preds[:, :, 1:, :])
        y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :batch_shape[1] - 1, :, :])
        x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :batch_shape[2] - 1, :])
        tv_loss = tv_weight * 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

        # Overall loss.
        loss = content_loss + style_loss + tv_loss
        train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        sess.run(tf.global_variables_initializer())

        # One Saver, built after every variable (incl. Adam slots) exists.
        # Constructing it per checkpoint kept adding save/restore ops to the
        # graph.
        saver = tf.train.Saver()
        save_dir = os.path.dirname(save_path)
        if save_dir:
            # Saver.save fails if the parent directory is missing; create it
            # now rather than after the first print_iterations batches.
            tf.gfile.MakeDirs(save_dir)

        uid = random.randint(1, 100)
        print("UID: %s" % uid)

        num_examples = len(content_targets)
        for epoch in range(epochs):
            iterations = 0
            while iterations * batch_size < num_examples:
                start_time = time.time()
                curr = iterations * batch_size
                X_batch = _load_batch(content_targets[curr:curr + batch_size],
                                      batch_shape)
                iterations += 1

                feed_dict = {X_content: X_batch}
                train_step.run(feed_dict=feed_dict)
                delta_time = time.time() - start_time
                if debug:
                    print("UID: %s, batch time: %s" % (uid, delta_time))

                is_print_iter = bool(print_iterations) and \
                    iterations % print_iterations == 0
                is_last = epoch == epochs - 1 and iterations * batch_size >= num_examples
                if not (is_print_iter or is_last):
                    continue

                to_get = [style_loss, content_loss, tv_loss, loss, preds]
                _style_loss, _content_loss, _tv_loss, _loss, _preds = sess.run(
                    to_get, feed_dict=feed_dict)
                losses = (_style_loss, _content_loss, _tv_loss, _loss)
                saver.save(sess, save_path)
                yield (_preds, losses, iterations, epoch)