Skip to content

Commit 0acfd8e

Browse files
committed
implement power low compression loss
1 parent 36b59da commit 0acfd8e

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

Diff for: config/default.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ train:
3939
final: 0.05
4040
summary_interval: 1
4141
checkpoint_interval: 1000
42+
complex_loss_ratio: 0.1 # the lambda in power-law compression loss computation
4243
---
4344
log:
4445
chkpt_dir: 'chkpt'

Diff for: utils/train.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,16 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp,
6464
mask = model(mixed_mag, dvec)
6565
output = mixed_mag * mask
6666

67-
# output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
68-
# target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
69-
loss = criterion(output, target_mag)
67+
# Power-law compression
68+
magnitude_loss = criterion(
69+
torch.pow(torch.abs(output), hp.audio.power),
70+
torch.pow(torch.abs(target_mag), hp.audio.power),
71+
)
72+
complex_loss = criterion(
73+
torch.pow(torch.clamp(output, min=0.0), hp.audio.power),
74+
torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power),
75+
)
76+
loss = magnitude_loss + complex_loss * hp.train.complex_loss_ratio
7077

7178
optimizer.zero_grad()
7279
loss.backward()

0 commit comments

Comments
 (0)