train.py
import torch
import torch.nn as nn
import torch.optim as optim
import os

from models.Discriminator import Discriminator
from models.Generator import Generator
from utils import init_weight
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader


def train(FLAGS):

    # Define the hyperparameters
    p_every = FLAGS.p_every
    s_every = FLAGS.s_every
    epochs = FLAGS.epochs
    dlr = FLAGS.dlr
    glr = FLAGS.glr
    beta1 = FLAGS.beta1
    beta2 = FLAGS.beta2
    z_size = FLAGS.zsize
    batch_size = FLAGS.batch_size
    rh = FLAGS.resize_height
    rw = FLAGS.resize_width
    d_path = FLAGS.dataset_path
    d_type = FLAGS.dataset_type

    # Preprocessing: resize, convert to tensors and scale pixels to [-1, 1]
    transform = transforms.Compose([transforms.Resize((rh, rw)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5))])

    # Download the chosen dataset if no local path was supplied
    if FLAGS.dataset_path is None:
        if d_type == "cars":
            if not os.path.exists('./datasets/cars_train'):
                os.system('sh ./datasets/dload.sh cars')
            d_path = './datasets/cars_train/'
        elif d_type == "flowers":
            if not os.path.exists('./datasets/flowers/'):
                os.system('sh ./datasets/dload.sh flowers')
            d_path = './datasets/flowers/'
        elif d_type == "dogs":
            if not os.path.exists('./datasets/jpg'):
                os.system('sh ./datasets/dload.sh dogs')
            d_path = './datasets/jpg/'

    train_data = datasets.ImageFolder(d_path, transform=transform)
    trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # Define the discriminator (D) and generator (G)
    dis = Discriminator(64)
    gen = Generator()

    # Apply weight initialization
    dis.apply(init_weight)
    gen.apply(init_weight)

    # Define the loss function
    criterion = nn.BCELoss()

    # Optimizers
    d_opt = optim.Adam(dis.parameters(), lr=dlr, betas=(beta1, beta2))
    g_opt = optim.Adam(gen.parameters(), lr=glr, betas=(beta1, beta2))

    # Train loop
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    dis.to(device)
    gen.to(device)

    train_losses = []

    # Target labels must be floats, since BCELoss expects float targets
    real_label = 1.
    fake_label = 0.

    for e in range(epochs):

        # Running sums of the discriminator and generator losses for this epoch
        td_loss = 0
        tg_loss = 0

        for batch_i, (real_images, _) in enumerate(trainloader):

            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            #### Train the Discriminator ####
            d_opt.zero_grad()

            # Loss on a batch of real images, labelled as real
            d_real = dis(real_images)
            label = torch.full((batch_size,), real_label, device=device)
            r_loss = criterion(d_real, label)
            r_loss.backward()

            # Loss on a batch of generated images, labelled as fake;
            # detach() keeps the generator out of this backward pass
            z = torch.randn(batch_size, z_size, 1, 1, device=device)
            fake_images = gen(z)
            label.fill_(fake_label)
            d_fake = dis(fake_images.detach())
            f_loss = criterion(d_fake, label)
            f_loss.backward()

            d_loss = r_loss + f_loss
            d_opt.step()

            #### Train the Generator ####
            g_opt.zero_grad()

            # The generator is rewarded when the discriminator labels its output as real
            label.fill_(real_label)
            d_fake2 = dis(fake_images)
            g_loss = criterion(d_fake2, label)
            g_loss.backward()
            g_opt.step()

            td_loss += d_loss.item()
            tg_loss += g_loss.item()

            if batch_i % p_every == 0:
                print('Epoch [{:5d} / {:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'
                      .format(e + 1, epochs, d_loss.item(), g_loss.item()))

        train_losses.append([td_loss, tg_loss])

        if e % s_every == 0:
            # Checkpoint both networks and save a grid of the latest fake images
            d_ckpt = {
                'model_state_dict': dis.state_dict(),
                'opt_state_dict': d_opt.state_dict()
            }
            g_ckpt = {
                'model_state_dict': gen.state_dict(),
                'opt_state_dict': g_opt.state_dict()
            }
            torch.save(d_ckpt, 'd-nm-{}.pth'.format(e))
            torch.save(g_ckpt, 'g-nm-{}.pth'.format(e))
            save_image(fake_images.detach(), 'fake_{}.png'.format(e), normalize=True)

    print('[INFO] Training completed successfully!')