import time
import datetime
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib
matplotlib.use("Qt5Agg")
from dataset.dataset import TextToSpeechDatasetCollection, TextToSpeechCollate
from params.params import Params as hp
from utils import audio
from modules.tacotron2 import Tacotron, TacotronLoss
from utils.logging import Logger
from utils.samplers import RandomImbalancedSampler, PerfectBatchSampler
from utils import lengths_to_mask, to_gpu
from modules.layers import ConstantEmbedding


def cos_decay(global_step, decay_steps):
"""Cosine decay function
Arguments:
global_step -- current training step
decay_steps -- number of decay steps
"""
global_step = min(global_step, decay_steps)
return 0.5 * (1 + math.cos(math.pi * global_step / decay_steps))
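

# Illustrative values of the schedule above (they follow directly from the formula):
# cos_decay(0, 1000) is 1.0, cos_decay(500, 1000) is roughly 0.5, and cos_decay(1000, 1000)
# is 0.0; the value stays at 0.0 for any step beyond decay_steps. In train() below this is
# used (when hp.constant_teacher_forcing is off) to anneal the teacher forcing ratio from 1
# towards 0.
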
def train(logging_start_epoch, epoch, data, model, criterion, optimizer):
"""Main training procedure.
Arguments:
logging_start_epoch -- number of the first epoch to be logged
epoch -- current epoch
data -- DataLoader which can provide batches for an epoch
model -- model to be trained
criterion -- instance of loss function to be optimized
optimizer -- instance of optimizer which will be used for parameter updates
"""
model.train()
# initialize counters, etc.
learning_rate = optimizer.param_groups[0]['lr']
cla = 0
done = 0
start_time = time.time()
optimizer.zero_grad() # once before loop start
# loop through epoch batches
for i, batch in enumerate(data):
global_step = done + epoch * len(data)
        # optimizer.zero_grad() is intentionally not called here; gradients are cleared
        # below, once per gradient accumulation group
# parse batch
batch = list(map(to_gpu, batch))
src, src_len, trg_mel, trg_lin, trg_len, stop_trg, spkrs, langs = batch
# get teacher forcing ratio
if hp.constant_teacher_forcing:
tf = hp.teacher_forcing
else:
tf = cos_decay(max(global_step - hp.teacher_forcing_start_steps, 0), hp.teacher_forcing_steps)
# run the model
post_pred, pre_pred, stop_pred, alignment, spkrs_pred, enc_output = model(src, src_len, trg_mel, trg_len, spkrs,
langs, tf)
# evaluate loss function
post_trg = trg_lin if hp.predict_linear else trg_mel
classifier = model._reversal_classifier if hp.reversal_classifier else None
loss, batch_losses = criterion(src_len, trg_len, pre_pred, trg_mel, post_pred, post_trg, stop_pred, stop_trg,
alignment, spkrs, spkrs_pred, enc_output, classifier)
# evaluate adversarial classifier accuracy, if present
if hp.reversal_classifier:
input_mask = lengths_to_mask(src_len)
trg_spkrs = torch.zeros_like(input_mask, dtype=torch.int64)
for s in range(hp.speaker_number):
speaker_mask = (spkrs == s)
trg_spkrs[speaker_mask] = s
matches = (trg_spkrs == torch.argmax(torch.nn.functional.softmax(spkrs_pred, dim=-1), dim=-1))
matches[~input_mask] = False
cla = torch.sum(matches).item() / torch.sum(input_mask).item()
        # scale the loss so that gradients accumulated over hp.gradient_accumulation batches average out
if hp.gradient_accumulation and hp.gradient_accumulation > 1:
loss = loss / hp.gradient_accumulation
loss.backward()
        # perform the optimizer step once per hp.gradient_accumulation batches (or every batch if accumulation is disabled)
if not hp.gradient_accumulation or hp.gradient_accumulation < 2 or (i + 1) % hp.gradient_accumulation == 0:
gradient = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.gradient_clipping)
optimizer.step()
optimizer.zero_grad()
# log training progress
if epoch >= logging_start_epoch:
Logger.training(global_step, batch_losses, gradient, learning_rate, time.time() - start_time, cla)
start_time = time.time()
        # update criterion states (e.g. decay of the guided attention loss)
criterion.update_states()
done += 1
def evaluate(epoch, data, model, criterion):
"""Main evaluation procedure.
Arguments:
epoch -- current epoch
data -- DataLoader which can provide validation batches
model -- model to be evaluated
criterion -- instance of loss function to measure performance
"""
model.eval()
# initialize counters, etc.
mcd, mcd_count = 0, 0
cla, cla_count = 0, 0
eval_losses = {}
# loop through epoch batches
with torch.no_grad():
for _, batch in enumerate(data):
# parse batch
batch = list(map(to_gpu, batch))
src, src_len, trg_mel, trg_lin, trg_len, stop_trg, spkrs, langs = batch
# run the model (twice, with and without teacher forcing)
post_pred, pre_pred, stop_pred, alignment, spkrs_pred, enc_output = model(src, src_len, trg_mel, trg_len,
spkrs, langs, 1.0)
post_pred_0, _, stop_pred_0, alignment_0, _, _ = model(src, src_len, trg_mel, trg_len, spkrs, langs, 0.0)
stop_pred_probs = torch.sigmoid(stop_pred_0)
# evaluate loss function
post_trg = trg_lin if hp.predict_linear else trg_mel
classifier = model._reversal_classifier if hp.reversal_classifier else None
loss, batch_losses = criterion(src_len, trg_len, pre_pred, trg_mel, post_pred, post_trg, stop_pred,
stop_trg, alignment, spkrs, spkrs_pred, enc_output, classifier)
            # compute mel cepstral distortion (MCD) between generated and reference spectrograms
for j, (gen, ref, stop) in enumerate(zip(post_pred_0, trg_mel, stop_pred_probs)):
stop_idxes = np.where(stop.cpu().numpy() > 0.5)[0]
stop_idx = min(np.min(stop_idxes) + hp.stop_frames, gen.size()[1]) if len(stop_idxes) > 0 else \
gen.size()[1]
gen = gen[:, :stop_idx].data.cpu().numpy()
ref = ref[:, :trg_len[j]].data.cpu().numpy()
if hp.normalize_spectrogram:
gen = audio.denormalize_spectrogram(gen, not hp.predict_linear)
ref = audio.denormalize_spectrogram(ref, True)
if hp.predict_linear: gen = audio.linear_to_mel(gen)
mcd = (mcd_count * mcd + audio.mel_cepstral_distorision(gen, ref, 'dtw')) / (mcd_count + 1)
mcd_count += 1
# compute adversarial classifier accuracy
if hp.reversal_classifier:
input_mask = lengths_to_mask(src_len)
trg_spkrs = torch.zeros_like(input_mask, dtype=torch.int64)
for s in range(hp.speaker_number):
speaker_mask = (spkrs == s)
trg_spkrs[speaker_mask] = s
matches = (trg_spkrs == torch.argmax(torch.nn.functional.softmax(spkrs_pred, dim=-1), dim=-1))
matches[~input_mask] = False
cla = (cla_count * cla + torch.sum(matches).item() / torch.sum(input_mask).item()) / (cla_count + 1)
cla_count += 1
# add batch losses to epoch losses
for k, v in batch_losses.items():
eval_losses[k] = v + eval_losses[k] if k in eval_losses else v
# normalize loss per batch
for k in eval_losses.keys():
eval_losses[k] /= len(data)
# log evaluation
Logger.evaluation(epoch + 1, eval_losses, mcd, src_len, trg_len, src, post_trg, post_pred, post_pred_0,
stop_pred_probs, stop_trg, alignment_0, cla)
    return sum(eval_losses.values())


class DataParallelPassthrough(torch.nn.DataParallel):
    """Wrapper around torch.nn.DataParallel that forwards attribute access to the wrapped module."""
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)
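

# Note: torch.nn.DataParallel keeps the wrapped model in self.module, so direct attribute
# access such as model._encoder or model._reversal_classifier (used when building the
# optimizer and inside train()/evaluate()) would otherwise fail in the multi-GPU case;
# the passthrough above keeps that access working.
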
if __name__ == '__main__':
import argparse
import os
torch.set_num_threads(12)
parser = argparse.ArgumentParser()
parser.add_argument("--base_directory", type=str, default=".", help="Base directory of the project.")
parser.add_argument("--checkpoint", type=str, default=None, help="Name of the initial checkpoint.")
parser.add_argument("--checkpoint_root", type=str, default="checkpoints", help="Base directory of checkpoints.")
parser.add_argument("--data_root", type=str, default="data", help="Base directory of datasets.")
parser.add_argument("--flush_seconds", type=int, default=60,
help="How often to flush pending summaries to tensorboard.")
parser.add_argument('--hyper_parameters', type=str, default=None, help="Name of the hyperparameters file.")
    parser.add_argument('--logging_start', type=int, default=1, help="First epoch to be logged.")
parser.add_argument('--max_gpus', type=int, default=2, help="Maximal number of GPUs of the local machine to use.")
parser.add_argument('--loader_workers', type=int, default=0, help="Number of subprocesses to use for data loading.")
    parser.add_argument('--fine_tuning', action='store_true', help="Fine-tune the checkpoint, possibly to an unseen language.")
parser.add_argument('--log_high_loss', action='store_true', help="Log batch details for high loss values.")
args = parser.parse_args()
# set up seeds and the target torch device
np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# prepare directory for checkpoints
checkpoint_dir = os.path.join(args.base_directory, args.checkpoint_root)
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
    # load the checkpoint (dict) with saved hyper-parameters (some of them may be overwritten below when fine-tuning)
if args.checkpoint:
checkpoint = os.path.join(checkpoint_dir, args.checkpoint)
checkpoint_state = torch.load(checkpoint, map_location='cpu')
hp.load_state_dict(checkpoint_state['parameters'])
used_input_characters = hp.phonemes if hp.use_phonemes else hp.characters
used_languages = hp.languages
used_speakers = hp.unique_speakers
# load hyperparameters
if args.hyper_parameters is not None:
hp_path = os.path.join(args.base_directory, 'params', f'{args.hyper_parameters}.json')
hp.load(hp_path)
if args.fine_tuning:
new_input_characters = hp.phonemes if hp.use_phonemes else hp.characters
extra_input_characters = [x for x in new_input_characters if x not in used_input_characters]
ordered_new_input_characters = used_input_characters + ''.join(extra_input_characters)
if hp.use_phonemes:
hp.phonemes = ordered_new_input_characters
else:
hp.characters = ordered_new_input_characters
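            # Hypothetical example of the merge above: if the checkpoint was trained with the
            # characters "abc" and the new hyper-parameters define "abd", the merged set becomes
            # "abcd"; indices of the original characters are preserved and new ones are appended.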
# load dataset
dataset = TextToSpeechDatasetCollection(os.path.join(args.data_root, hp.dataset))
if hp.multi_language and hp.balanced_sampling and hp.perfect_sampling:
dp_devices = args.max_gpus if hp.parallelization and torch.cuda.device_count() > 1 else 1
        train_sampler = PerfectBatchSampler(dataset.train, hp.languages, hp.batch_size,
                                            data_parallel_devices=dp_devices, shuffle=True, drop_last=True)
        train_data = DataLoader(dataset.train, batch_sampler=train_sampler, pin_memory=False,
                                collate_fn=TextToSpeechCollate(False), num_workers=args.loader_workers)
        eval_sampler = PerfectBatchSampler(dataset.dev, hp.languages, hp.batch_size,
                                           data_parallel_devices=dp_devices, shuffle=True)
        eval_data = DataLoader(dataset.dev, batch_sampler=eval_sampler, pin_memory=False,
                               collate_fn=TextToSpeechCollate(False), num_workers=args.loader_workers)
else:
sampler = RandomImbalancedSampler(dataset.train) if hp.multi_language and hp.balanced_sampling else None
train_data = DataLoader(dataset.train, batch_size=hp.batch_size, drop_last=True,
shuffle=(not hp.multi_language or not hp.balanced_sampling), sampler=sampler,
pin_memory=False, collate_fn=TextToSpeechCollate(True), num_workers=args.loader_workers)
eval_data = DataLoader(dataset.dev, batch_size=hp.batch_size, drop_last=False, shuffle=False, pin_memory=False,
collate_fn=TextToSpeechCollate(True), num_workers=args.loader_workers)
# find out number of unique speakers and languages
hp.speaker_number = 0 if not hp.multi_speaker else dataset.train.get_num_speakers()
hp.language_number = 0 if not hp.multi_language else len(hp.languages)
# save all found speakers to hyper parameters
if hp.multi_speaker and not args.checkpoint:
hp.unique_speakers = dataset.train.unique_speakers
    # acquire dataset-dependent constants; when starting from a checkpoint, keep the values saved with it
if not args.checkpoint:
# compute per-channel constants for spectrogram normalization
hp.mel_normalize_mean, hp.mel_normalize_variance = dataset.train.get_normalization_constants(True)
if hp.predict_linear:
hp.lin_normalize_mean, hp.lin_normalize_variance = dataset.train.get_normalization_constants(False)
if args.fine_tuning:
if hp.use_phonemes:
hp.phonemes = used_input_characters
else:
hp.characters = used_input_characters
# instantiate model
if torch.cuda.is_available():
model = Tacotron().cuda()
if hp.parallelization and args.max_gpus > 1 and torch.cuda.device_count() > 1:
model = DataParallelPassthrough(model, device_ids=list(range(args.max_gpus)))
else:
model = Tacotron()
# instantiate optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=hp.learning_rate, weight_decay=hp.weight_decay)
if hp.encoder_optimizer:
encoder_params = list(model._encoder.parameters())
other_params = list(model._decoder.parameters()) + list(model._postnet.parameters()) + list(
model._prenet.parameters()) + list(model._embedding.parameters()) + list(model._attention.parameters())
if hp.reversal_classifier:
other_params += list(model._reversal_classifier.parameters())
optimizer = torch.optim.Adam(
[{'params': other_params}, {'params': encoder_params, 'lr': hp.learning_rate_encoder}],
lr=hp.learning_rate, weight_decay=hp.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, hp.learning_rate_decay_each // len(train_data),
gamma=hp.learning_rate_decay)
criterion = TacotronLoss(hp.guided_attention_steps, hp.guided_attention_toleration, hp.guided_attention_gain)
    # load model weights and optimizer/scheduler states from the checkpoint state dictionary
initial_epoch = 0
if args.checkpoint:
        # load model state dict (can be incomplete if only a part of the model was pre-trained)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in checkpoint_state['model'].items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
        if checkpoint_state['epoch'] is not None:
initial_epoch = checkpoint_state['epoch'] + 1
if checkpoint_state['optimizer']:
optimizer.load_state_dict(checkpoint_state['optimizer'])
if checkpoint_state['scheduler']:
scheduler.load_state_dict(checkpoint_state['scheduler'])
if checkpoint_state['criterion']:
criterion.load_state_dict(checkpoint_state['criterion'])
if args.fine_tuning:
# make speaker and language embeddings constant
hp.embedding_type = "constant"
if hp.multi_speaker:
embedding = model._decoder._speaker_embedding.weight.mean(dim=0)
model._decoder._speaker_embedding = ConstantEmbedding(embedding)
if hp.multi_language:
embedding = model._decoder._language_embedding.weight.mean(dim=0)
model._decoder._language_embedding = ConstantEmbedding(embedding)
# enlarge the input embedding to fit all new characters
if hp.use_phonemes:
hp.phonemes = ordered_new_input_characters
else:
hp.characters = ordered_new_input_characters
            num_news = len(extra_input_characters)
            # enlarge the embedding only when there actually are new characters
            # (a slice with -0 below would otherwise re-initialize the whole embedding)
            if num_news > 0:
                input_embedding = model._embedding.weight
                with torch.no_grad():
                    input_embedding = torch.nn.functional.pad(input_embedding, (0, 0, 0, num_news))
                    torch.nn.init.xavier_uniform_(input_embedding[-num_news:, :])
                model._embedding = torch.nn.Embedding.from_pretrained(input_embedding, padding_idx=0, freeze=False)
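            # Note on the enlargement above: torch.nn.functional.pad with (0, 0, 0, num_news)
            # appends num_news rows at the bottom of the embedding matrix, so characters known
            # to the checkpoint keep their pretrained vectors, while the appended rows for the
            # new characters start from a fresh Xavier initialization.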
# initialize logger
log_dir = os.path.join(args.base_directory, "logs",
f'{hp.version}-{datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")}')
Logger.initialize(log_dir, args.flush_seconds)
print("Log directory", log_dir)
# training loop
best_eval = float('inf')
for epoch in range(initial_epoch, hp.epochs):
print("Starting epoch:", epoch + 1)
train(args.logging_start, epoch, train_data, model, criterion, optimizer)
if hp.learning_rate_decay_start - hp.learning_rate_decay_each < epoch * len(train_data):
scheduler.step()
eval_loss = evaluate(epoch, eval_data, model, criterion)
if (epoch + 1) % hp.checkpoint_each_epochs == 0:
# save checkpoint together with hyper-parameters, optimizer and scheduler states
checkpoint_file = f'{checkpoint_dir}/{hp.version}-epoch_{int(epoch + 1):03d}-loss_{eval_loss:2.4f}'
state_dict = {'epoch': epoch, 'model': model.state_dict(), 'optimizer': optimizer.state_dict(),
'scheduler' : scheduler.state_dict(), 'parameters': hp.state_dict(),
'criterion' : criterion.state_dict()}
torch.save(state_dict, checkpoint_file)
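
# Example of resuming from a checkpoint produced above (illustrative invocation; the checkpoint
# file is looked up under <base_directory>/<checkpoint_root>, see the --checkpoint argument):
#   python train.py --checkpoint <checkpoint-file-name>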