Skip to content

Commit 1c86ea0

Browse files
author
T-Brain 정원진
committed
HDF5 bug fix
1 parent 14bfaad commit 1c86ea0

File tree

2 files changed

+36
-28
lines changed

2 files changed

+36
-28
lines changed

model.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
class CycleGAN(object):
88

99
def __init__(self, num_features, mode = 'train',
10-
log_dir = './log', model_name='tmp.ckpt', FEAT=True, gen_model='generator_gatedcnn'):
10+
log_dir = './log', model_name='tmp.ckpt', gen_model='generator_gatedcnn'):
1111

1212
self.num_features = num_features
1313
self.input_shape = [None, num_features, None] # [batch_size, num_features, num_frames]
@@ -133,11 +133,11 @@ def optimizer_initializer(self):
133133
self.discriminator_optimizer = tf.train.AdamOptimizer(learning_rate = self.discriminator_learning_rate, beta1 = 0.5).minimize(self.discriminator_loss, var_list = self.discriminator_vars)
134134
self.generator_optimizer = tf.train.AdamOptimizer(learning_rate = self.generator_learning_rate, beta1 = 0.5).minimize(self.generator_loss, var_list = self.generator_vars)
135135

136-
def train(self, input_A, input_B, lambda_cycle, lambda_identity, lambda_feat, generator_learning_rate, discriminator_learning_rate):
136+
def train(self, input_A, input_B, lambda_cycle, lambda_identity, generator_learning_rate, discriminator_learning_rate):
137137

138138
generation_A, generation_B, generator_loss, _, generator_summaries, generator_loss_A2B = self.sess.run(
139139
[self.generation_A, self.generation_B, self.generator_loss, self.generator_optimizer, self.generator_summaries, self.generator_loss_A2B], \
140-
feed_dict = {self.lambda_cycle: lambda_cycle, self.lambda_identity: lambda_identity, self.lambda_feat: lambda_feat, self.input_A_real: input_A, self.input_B_real: input_B, self.generator_learning_rate: generator_learning_rate})
140+
feed_dict = {self.lambda_cycle: lambda_cycle, self.lambda_identity: lambda_identity, self.input_A_real: input_A, self.input_B_real: input_B, self.generator_learning_rate: generator_learning_rate})
141141

142142
self.writer.add_summary(generator_summaries, self.train_step)
143143

train.py

Lines changed: 33 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -34,24 +34,32 @@ def train(train_A_dir, train_B_dir, model_dir, model_name, random_seed, validati
3434
start_time = time.time()
3535

3636
# Data load using HDF5, world vocoder has extra high complexity.(e.g., 1 song -> 5 min)
37-
if not os.path.exists(hdf5_path[0]):
38-
os.makedirs(hdf5_path[0])
39-
os.makedirs(hdf5_path[1])
40-
f0s_A, timeaxes_A, sps_A, aps_A, coded_sps_A = world_encode_data_toSave(num_mcep, hdf5_dir=hdf5_path[0],
41-
wav_dir=train_A_dir,
42-
sr=sampling_rate,
43-
frame_period=5.0,
44-
coded_dim24=24,
45-
coded_dim36=36)
46-
f0s_B, timeaxes_B, sps_B, aps_B, coded_sps_B = world_encode_data_toSave(num_mcep, hdf5_dir=hdf5_path[1],
47-
wav_dir=train_B_dir,
48-
sr=sampling_rate,
49-
frame_period=5.0,
50-
coded_dim24=24,
51-
coded_dim36=36)
37+
if hdf5_path[0] is None:
38+
wavs_A = load_wavs(wav_dir=train_A_dir, sr=sampling_rate)
39+
wavs_B = load_wavs(wav_dir=train_B_dir, sr=sampling_rate)
40+
f0s_A, timeaxes_A, sps_A, aps_A, coded_sps_A = world_encode_data(wavs=wavs_A, fs=sampling_rate,
41+
frame_period=frame_period, coded_dim=num_mcep)
42+
f0s_B, timeaxes_B, sps_B, aps_B, coded_sps_B = world_encode_data(wavs=wavs_B, fs=sampling_rate,
43+
frame_period=frame_period, coded_dim=num_mcep)
5244
else:
53-
f0s_A, timeaxes_A, sps_A, aps_A, coded_sps_A = world_encode_data_toLoad(num_mcep, hdf5_dir=hdf5_path[0])
54-
f0s_B, timeaxes_B, sps_B, aps_B, coded_sps_B = world_encode_data_toLoad(num_mcep, hdf5_dir=hdf5_path[1])
45+
if not os.path.exists(hdf5_path[0]):
46+
os.makedirs(hdf5_path[0])
47+
os.makedirs(hdf5_path[1])
48+
f0s_A, timeaxes_A, sps_A, aps_A, coded_sps_A = world_encode_data_toSave(num_mcep, hdf5_dir=hdf5_path[0],
49+
wav_dir=train_A_dir,
50+
sr=sampling_rate,
51+
frame_period=5.0,
52+
coded_dim24=24,
53+
coded_dim36=36)
54+
f0s_B, timeaxes_B, sps_B, aps_B, coded_sps_B = world_encode_data_toSave(num_mcep, hdf5_dir=hdf5_path[1],
55+
wav_dir=train_B_dir,
56+
sr=sampling_rate,
57+
frame_period=5.0,
58+
coded_dim24=24,
59+
coded_dim36=36)
60+
else:
61+
f0s_A, timeaxes_A, sps_A, aps_A, coded_sps_A = world_encode_data_toLoad(num_mcep, hdf5_dir=hdf5_path[0])
62+
f0s_B, timeaxes_B, sps_B, aps_B, coded_sps_B = world_encode_data_toLoad(num_mcep, hdf5_dir=hdf5_path[1])
5563

5664

5765
log_f0s_mean_A, log_f0s_std_A = logf0_statistics(f0s_A)
@@ -99,7 +107,7 @@ def train(train_A_dir, train_B_dir, model_dir, model_name, random_seed, validati
99107
# ---------------------------------------------- Data preprocessing ---------------------------------------------- #
100108

101109
# Model define
102-
model = CycleGAN(num_features = num_mcep, log_dir=tensorboard_log_dir, model_name=model_name, FEAT=FEAT, gen_model=gen_model)
110+
model = CycleGAN(num_features = num_mcep, log_dir=tensorboard_log_dir, model_name=model_name, gen_model=gen_model)
103111
# load model
104112
if os.path.exists(os.path.join(model_dir, (model_name+".index"))) == True:
105113
model.load(filepath=os.path.join(model_dir, model_name))
@@ -129,7 +137,7 @@ def train(train_A_dir, train_B_dir, model_dir, model_name, random_seed, validati
129137

130138
generator_loss, discriminator_loss, generator_loss_A2B = model.train\
131139
(input_A = dataset_A[start:end], input_B = dataset_B[start:end],
132-
lambda_cycle = lambda_cycle, lambda_identity = lambda_identity, lambda_feat = lambda_feat,
140+
lambda_cycle = lambda_cycle, lambda_identity = lambda_identity,
133141
generator_learning_rate = generator_learning_rate, discriminator_learning_rate = discriminator_learning_rate)
134142
model.summary()
135143

@@ -203,19 +211,19 @@ def train(train_A_dir, train_B_dir, model_dir, model_name, random_seed, validati
203211

204212
parser = argparse.ArgumentParser(description='Train CycleGAN-VC2 model')
205213

206-
train_A_dir_default = './data/vcc2016_training/SF1'
207-
train_B_dir_default = './data/vcc2016_training/TF2'
214+
train_A_dir_default = '/root/onejin/train/ma'
215+
train_B_dir_default = '/root/onejin/train/fe'
208216
model_dir_default = './model/sf1_tf2'
209217
model_name_default = 'sf1_tf2.ckpt'
210218
random_seed_default = 0
211-
validation_A_dir_default = './data/evaluation_all/SF1'
212-
validation_B_dir_default = './data/evaluation_all/TF2'
219+
validation_A_dir_default = '/root/onejin/test/ma'
220+
validation_B_dir_default = '/root/onejin/test/fe'
213221
output_dir_default = './validation_output'
214222
tensorboard_log_dir_default = './log'
215223
generator_model_default = 'CycleGAN-VC2'
216224
MCEPs_dim_default = 32
217-
hdf5A_path_defalut = './data/vcc2016_training/SF1_hdf5'
218-
hdf5B_path_defalut = './data/vcc2016_training/SF1_hdf5'
225+
hdf5A_path_defalut = None
226+
hdf5B_path_defalut = None
219227
lambda_cycle_defalut = 10.0
220228
lambda_identity_defalut = 5.0
221229

0 commit comments

Comments
 (0)