-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_flux_mapping_net.py
79 lines (69 loc) · 3.16 KB
/
train_flux_mapping_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
train_flux_mapping_net.py
-------------------------
Trains the flux mapping network, saving the best checkpoints.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
print(tf.__version__)
import sys
import os
import time
from flux_mapping_net import DM2Flux, plot_to_image
import numpy as np
import matplotlib.pyplot as plt
# Experiment identifier from the command line; names the output directory.
if len(sys.argv) < 2:
    print('Usage: python train_flux_mapping_net.py <run_num>')
    sys.exit(1)
run_num = sys.argv[1]
datapath = './data/univ_000_real.hdf5'
# Set up directory: ./training_runs/fluxmapnet_<run_num>/
baseDir = './training_runs/'
expDir = baseDir + 'fluxmapnet_' + str(run_num) + '/'
# makedirs with exist_ok is race-free, unlike the isdir()-then-mkdir() pattern.
os.makedirs(baseDir, exist_ok=True)
if os.path.isdir(expDir):
    # Refuse to clobber an existing run's logs/checkpoints.
    print("Experiment directory %s already exists, exiting"%expDir)
    sys.exit()
os.mkdir(expDir)
# Network wrapper owning datasets, models, optimizers and summary writers
# (project-local class; see flux_mapping_net.py).
net = DM2Flux(datapath, expDir)
# Best (lowest) metrics seen so far, used to decide when to checkpoint.
bestchi = 1e15
bestL1 = 1e15
# Main training loop: iterate the dataset for net.EPOCHS epochs, logging
# images/scalars every net.log_freq optimizer steps and checkpointing
# whenever a monitored metric improves (after a burn-in of 60000 steps).
for epoch in range(net.EPOCHS):
    start = time.time()
    # All tf.summary calls below record into this run's TensorBoard log dir.
    with net.train_summary_writer.as_default():
        for input_image, target in net.train_dataset:
            # One generator+discriminator update (project-defined step).
            net.train_step(input_image, target)
            # Log periodically, keyed on the generator optimizer's step count.
            if tf.equal(net.generator_optimizer.iterations % net.log_freq, 0):
                # Generate sample imgs
                fig = net.generate_images()
                tf.summary.image("genimg", plot_to_image(fig),
                                 step=net.generator_optimizer.iterations)
                # Pixel histogram figure plus two summary metrics
                # (presumably chi-square distance and mean L1 error between
                # generated and target pixel distributions — TODO confirm
                # against flux_mapping_net.pix_hist).
                fig, chi, meanL1 = net.pix_hist()
                tf.summary.image("pixhist", plot_to_image(fig),
                                 step=net.generator_optimizer.iterations)
                # Log scalars
                tf.summary.scalar('G_loss', net.G_loss.result(),
                                  step=net.generator_optimizer.iterations)
                tf.summary.scalar('G_loss_gan', net.G_loss_gan.result(),
                                  step=net.generator_optimizer.iterations)
                tf.summary.scalar('G_loss_L1', net.G_loss_L1.result(),
                                  step=net.generator_optimizer.iterations)
                tf.summary.scalar('D_loss', net.D_loss.result(),
                                  step=net.generator_optimizer.iterations)
                tf.summary.scalar('chi', chi, step=net.generator_optimizer.iterations)
                tf.summary.scalar('meanL1', meanL1, step=net.generator_optimizer.iterations)
                # Reset the running-mean loss metrics so each logged value
                # averages only the steps since the previous log point.
                net.D_loss.reset_states()
                net.G_loss.reset_states()
                net.G_loss_gan.reset_states()
                net.G_loss_L1.reset_states()
                # Save model if chi, L1 is good
                # (only after 60000 steps, so early noisy metrics
                # don't claim the "best" checkpoint slots).
                if net.generator_optimizer.iterations > 60000:
                    if chi < bestchi:
                        # checkpoint.write() overwrites a single fixed-prefix
                        # checkpoint (no numbered history, unlike save()).
                        net.checkpoint.write(file_prefix = os.path.join(net.checkpoint_dir, 'BESTCHI'))
                        bestchi = chi
                        print('BESTCHI: iter=%d, chi=%f'%(net.generator_optimizer.iterations, chi))
                    if meanL1 < bestL1:
                        net.checkpoint.write(file_prefix = os.path.join(net.checkpoint_dir, 'BESTL1'))
                        bestL1 = meanL1
                        print('BESTL1: iter=%d, L1=%f'%(net.generator_optimizer.iterations, meanL1))
    # Per-epoch wall-clock timing.
    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))
print('DONE')