Skip to content

Commit

Permalink
🐛fix
Browse files Browse the repository at this point in the history
A bug fix that fixes an issue of training not working when user doesn't doesn't change the n steps value from 0 to positives ( or if decides to not use the metric )
  • Loading branch information
codename0og authored Dec 18, 2024
1 parent 6cfc58a commit b2891c2
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,15 @@

# Globals
global_step = 0

mini_batches = n_value

warmup_epochs = warmup_duration
warmup_enabled = use_warmup
warmup_completed = False

averaging_enabled = mini_batches > 0 # boolean


# -------------------------- Custom functions land in here --------------------------

Expand Down Expand Up @@ -302,9 +306,12 @@ def run(

if 'warmup_completed' not in globals():
warmup_completed = False

# Warmup init msg:
if rank == 0 and warmup_enabled:
print(f"////// Warmup enabled. Training will gradually increase learning rates over {warmup_epochs} epochs. //////")
print(f"////// WARMUP ENABLED: Training will gradually increase learning rates over: {warmup_epochs} epochs. //////")
# Averaging init msg:
if rank == 0 and averaging_enabled:
print(f"////// RUNNING AVG LOSS ENABLED: Training will log averaged losses every: {mini_batches} steps. //////")

if rank == 0:
writer = SummaryWriter(log_dir=experiment_dir)
Expand Down Expand Up @@ -369,7 +376,6 @@ def run(
).to(device)



net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm).to(device)

optim_g = Ranger(
Expand All @@ -382,8 +388,8 @@ def run(
alpha=0.5,
k=6,
N_sma_threshhold=5, # 4 or 5 can be tried
use_gc=False,
gc_conv_only=False,
use_gc=True,
gc_conv_only=True,
gc_loc=False,
)
optim_d = Ranger(
Expand All @@ -396,8 +402,8 @@ def run(
alpha=0.5,
k=6,
N_sma_threshhold=5, # 4 or 5 can be tried
use_gc=False,
gc_conv_only=False,
use_gc=True,
gc_conv_only=True,
gc_loc=False,
)

Expand Down Expand Up @@ -635,11 +641,12 @@ def train_and_evaluate(
# Over N mini-batches loss averaging
N = mini_batches # Number of mini-batches after which the loss is logged
# Running loss init
running_loss_gen_all = 0.0
running_loss_gen_fm = 0.0
running_loss_gen_mel = 0.0
running_loss_gen_kl = 0.0
running_loss_disc_all = 0.0
if averaging_enabled:
running_loss_gen_all = 0.0
running_loss_gen_fm = 0.0
running_loss_gen_mel = 0.0
running_loss_gen_kl = 0.0
running_loss_disc_all = 0.0

with tqdm(total=len(train_loader), leave=False) as pbar:
for batch_idx, info in data_iterator:
Expand Down Expand Up @@ -689,7 +696,9 @@ def train_and_evaluate(
with autocast(enabled=False):
loss_disc = discriminator_loss(y_d_hat_r, y_d_hat_g) # loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)

running_loss_disc_all += loss_disc #.item() # For Discriminator
# Accumulate losses for discriminator
if averaging_enabled:
running_loss_disc_all += loss_disc #.item() # For Discriminator

# Backward and update for discs:

Expand Down Expand Up @@ -752,23 +761,25 @@ def train_and_evaluate(

global_step += 1
pbar.update(1)

# Accumulate losses for generator
running_loss_gen_all += loss_gen_all #.item() # For Generator - all
running_loss_gen_fm += loss_fm #.item() # For Generator - FM
running_loss_gen_mel += loss_mel #.item() # For Generator - MEL
running_loss_gen_kl += loss_kl #.item() # For Generator - KL
if averaging_enabled:
running_loss_gen_all += loss_gen_all #.item() # For Generator - all
running_loss_gen_fm += loss_fm #.item() # For Generator - FM
running_loss_gen_mel += loss_mel #.item() # For Generator - MEL
running_loss_gen_kl += loss_kl #.item() # For Generator - KL


# Logging of the averaged loss every N mini-batches
if rank == 0 and (batch_idx + 1) % N == 0:
# For Generator
if averaging_enabled and rank == 0 and (batch_idx + 1) % N == 0:
# For Generator:
avg_loss_gen_all = running_loss_gen_all / N
avg_loss_gen_fm = running_loss_gen_fm / N
avg_loss_gen_mel = running_loss_gen_mel / N
avg_loss_gen_kl = running_loss_gen_kl / N
# For Discriminator
# For Discriminator:
avg_loss_disc_all = running_loss_disc_all / N
# Logging
# Logging:
writer.add_scalar('Average_Loss/Generator_Avg_Total', avg_loss_gen_all, global_step)
writer.add_scalar('Average_Loss/Generator_Avg_FM', avg_loss_gen_fm, global_step)
writer.add_scalar('Average_Loss/Generator_Avg_MEL', avg_loss_gen_mel, global_step)
Expand Down

0 comments on commit b2891c2

Please sign in to comment.