Skip to content

Commit 62fa31d

Browse files
committed
able to return the loss breakdown for consistency fm
1 parent 2598074 commit 62fa31d

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.24"
3+
version = "0.0.25"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Tuple, List, Literal, Callable
55

66
import torch
7+
from torch import Tensor
78
from torch import nn, pi, from_numpy
89
from torch.nn import Module, ModuleList
910
import torch.nn.functional as F
@@ -263,7 +264,8 @@ def ode_fn(t, x):
263264
def forward(
264265
self,
265266
data,
266-
noise = None,
267+
noise: Tensor | None = None,
268+
return_loss_breakdown = False,
267269
**model_kwargs
268270
):
269271
batch, *data_shape = data.shape
@@ -337,21 +339,28 @@ def get_noised_and_flows(model, t):
337339

338340
# losses
339341

340-
loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data)
342+
main_flow_loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data)
343+
344+
consistency_loss = data_match_loss = velocity_match_loss = 0.
341345

342346
if self.use_consistency:
343-
# add velocity consistency loss from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1
347+
# consistency losses from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1
344348

345-
α = self.consistency_velocity_match_alpha
349+
data_match_loss = F.mse_loss(pred_data, ema_pred_data)
350+
velocity_match_loss = F.mse_loss(pred_flow, ema_pred_flow)
346351

347-
consistency_loss = (
348-
F.mse_loss(pred_data, ema_pred_data) +
349-
α * F.mse_loss(pred_flow, ema_pred_flow)
350-
)
352+
consistency_loss = data_match_loss + velocity_match_loss * self.consistency_velocity_match_alpha
351353

352-
loss = loss + consistency_loss * self.consistency_loss_weight
354+
# total loss
353355

354-
return loss
356+
total_loss = main_flow_loss + consistency_loss * self.consistency_loss_weight
357+
358+
if not return_loss_breakdown:
359+
return total_loss
360+
361+
# loss breakdown
362+
363+
return loss, (main_flow_loss, data_match_loss, velocity_match_loss)
355364

356365
# reflow wrapper
357366

0 commit comments

Comments
 (0)