Skip to content

Commit

Permalink
able to return the loss breakdown for consistency fm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 18, 2024
1 parent 2598074 commit 62fa31d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rectified-flow-pytorch"
version = "0.0.24"
version = "0.0.25"
description = "Rectified Flow in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
29 changes: 19 additions & 10 deletions rectified_flow_pytorch/rectified_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Tuple, List, Literal, Callable

import torch
from torch import Tensor
from torch import nn, pi, from_numpy
from torch.nn import Module, ModuleList
import torch.nn.functional as F
Expand Down Expand Up @@ -263,7 +264,8 @@ def ode_fn(t, x):
def forward(
self,
data,
noise = None,
noise: Tensor | None = None,
return_loss_breakdown = False,
**model_kwargs
):
batch, *data_shape = data.shape
Expand Down Expand Up @@ -337,21 +339,28 @@ def get_noised_and_flows(model, t):

# losses

loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data)
main_flow_loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data)

consistency_loss = data_match_loss = velocity_match_loss = 0.

if self.use_consistency:
# add velocity consistency loss from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1
# consistency losses from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1

α = self.consistency_velocity_match_alpha
data_match_loss = F.mse_loss(pred_data, ema_pred_data)
velocity_match_loss = F.mse_loss(pred_flow, ema_pred_flow)

consistency_loss = (
F.mse_loss(pred_data, ema_pred_data) +
α * F.mse_loss(pred_flow, ema_pred_flow)
)
consistency_loss = data_match_loss + velocity_match_loss * self.consistency_velocity_match_alpha

loss = loss + consistency_loss * self.consistency_loss_weight
# total loss

return loss
total_loss = main_flow_loss + consistency_loss * self.consistency_loss_weight

if not return_loss_breakdown:
return total_loss

# loss breakdown

return loss, (main_flow_loss, data_match_loss, velocity_match_loss)

# reflow wrapper

Expand Down

0 comments on commit 62fa31d

Please sign in to comment.