-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathmain.py
60 lines (54 loc) · 1.93 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: GANimorph.py
# Author: Aaron Gokaslan ([email protected])
import cv2
import os, sys
import argparse
from six.moves import map, zip
import numpy as np
from glob import glob
from model import Model
from tensorpack import *
from tensorpack.utils.viz import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
import tensorflow as tf
from tensorflow.python.training import moving_averages
from utils import *
from GAN import GANTrainer, MultiGPUGANTrainer, SeparateGANTrainer, GANModelDesc
"""
The official code for the paper Improved Shape Deformation in Unsupervised
Image-to-Image Translation.
Requires Tensorpack and its related dependencies.
Author: Aaron Gokaslan [email protected]
"""
# Command-line interface. Building the parser has no side effects, so it stays
# at module level; the arguments themselves are parsed only when this file is
# run as a script (see the guard at the bottom).
parser = argparse.ArgumentParser()
# NOTE(review): the help text below mentions CelebA files; confirm it matches
# the directory layout that get_data() actually expects.
parser.add_argument(
    '--data', required=True,
    help='the img_align_celeba directory. should also contain list_attr_celeba.txt')
parser.add_argument('--load', help='load model')


def main(args):
    """Run GAN training with the parsed command-line arguments.

    Args:
        args: namespace produced by ``parser.parse_args()``. ``args.data`` is
            passed to ``get_data`` and ``VisualizeTestSet``; ``args.load``,
            when set, is handed to ``SaverRestore`` to initialise the session.
    """
    logger.auto_set_dir()
    data = QueueInput(get_data(args.data))

    callbacks = [
        PeriodicTrigger(ModelSaver(), every_k_epochs=20),
        PeriodicTrigger(VisualizeTestSet(args.data), every_k_epochs=3),
        # Schedule points are (epoch, value) pairs with linear interpolation
        # between them: 2e-4 at epoch 150, down to 0 at epoch 300.
        ScheduledHyperParamSetter(
            'learning_rate',
            [(150, 2e-4), (300, 0)], interp='linear'),
    ]
    # Only restore a checkpoint when --load was given.
    session_init = SaverRestore(args.load) if args.load else None

    # train 1 D after 2 G
    SeparateGANTrainer(data, Model(), 2).train_with_defaults(
        callbacks=callbacks,
        steps_per_epoch=1000,
        session_init=session_init
    )
    # SeparateGANTrainer(config, 2).train()
    # If you want to run across GPUs use code similar to below.
    # nr_gpu = get_nr_gpu()
    # config.nr_tower = max(get_nr_gpu(), 1)
    # if config.nr_tower == 1:
    #     GANTrainer(config).train()
    # else:
    #     MultiGPUGANTrainer(config).train()


if __name__ == '__main__':
    # Parse here rather than at import time so that importing this module
    # never consumes (or fails on) the importing process's sys.argv.
    args = parser.parse_args()
    main(args)