Skip to content

Commit 251018b

Browse files
committed
incorporate pr maum-ai#18 [WIP] implement power low compression loss
1 parent 3d70627 commit 251018b

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

Diff for: utils/train.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,20 @@ def train(args, pt_dir, chkpt_path, trainloader, testloader, writer, logger, hp,
6666

6767
# output = torch.pow(torch.clamp(output, min=0.0), hp.audio.power)
6868
# target_mag = torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power)
69-
loss = criterion(output, target_mag)
69+
if hp.train.complex_loss_ratio > 0:
70+
# Power-law compression
71+
magnitude_loss = criterion(
72+
torch.pow(torch.abs(output), hp.audio.power),
73+
torch.pow(torch.abs(target_mag), hp.audio.power),
74+
)
75+
complex_loss = criterion(
76+
torch.pow(torch.clamp(output, min=0.0), hp.audio.power),
77+
torch.pow(torch.clamp(target_mag, min=0.0), hp.audio.power),
78+
)
79+
loss = magnitude_loss + complex_loss * hp.train.complex_loss_ratio
80+
81+
else:
82+
loss = criterion(output, target_mag)
7083

7184
optimizer.zero_grad()
7285
loss.backward()

0 commit comments

Comments
 (0)