diff --git a/pyproject.toml b/pyproject.toml index ce665a0..0e64160 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rectified-flow-pytorch" -version = "0.0.24" +version = "0.0.25" description = "Rectified Flow in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/rectified_flow_pytorch/rectified_flow.py b/rectified_flow_pytorch/rectified_flow.py index 875a2c9..14e92fc 100644 --- a/rectified_flow_pytorch/rectified_flow.py +++ b/rectified_flow_pytorch/rectified_flow.py @@ -4,6 +4,7 @@ from typing import Tuple, List, Literal, Callable import torch +from torch import Tensor from torch import nn, pi, from_numpy from torch.nn import Module, ModuleList import torch.nn.functional as F @@ -263,7 +264,8 @@ def ode_fn(t, x): def forward( self, data, - noise = None, + noise: Tensor | None = None, + return_loss_breakdown = False, **model_kwargs ): batch, *data_shape = data.shape @@ -337,21 +339,28 @@ def get_noised_and_flows(model, t): # losses - loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data) + main_flow_loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data) + + consistency_loss = data_match_loss = velocity_match_loss = 0. if self.use_consistency: - # add velocity consistency loss from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1 + # consistency losses from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1 - α = self.consistency_velocity_match_alpha + data_match_loss = F.mse_loss(pred_data, ema_pred_data) + velocity_match_loss = F.mse_loss(pred_flow, ema_pred_flow) - consistency_loss = ( - F.mse_loss(pred_data, ema_pred_data) + - α * F.mse_loss(pred_flow, ema_pred_flow) - ) + consistency_loss = data_match_loss + velocity_match_loss * self.consistency_velocity_match_alpha - loss = loss + consistency_loss * self.consistency_loss_weight + # total loss - return loss + total_loss = main_flow_loss + consistency_loss * self.consistency_loss_weight + + if not return_loss_breakdown: + return total_loss + + # loss breakdown + + return loss, (main_flow_loss, data_match_loss, velocity_match_loss) # reflow wrapper