 from copy import deepcopy

 import torch
+from torch import nn
 from torch.nn import Module
 import torch.nn.functional as F

 from torchdiffeq import odeint

+import torchvision
+from torchvision.models import VGG16_Weights
+
+from einops import reduce, rearrange
+
 # helpers

 def exists(v):
@@ -22,6 +28,61 @@ def append_dims(t, ndims):
     shape = t.shape
     return t.reshape(*shape, *((1,) * ndims))

+# losses
+
+class LPIPSLoss(Module):
+    def __init__(
+        self,
+        vgg: Module | None = None,
+        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
+    ):
+        super().__init__()
+
+        if not exists(vgg):
+            vgg = torchvision.models.vgg16(weights = vgg_weights)
+            # drop the final dropout + classification layer, keeping the penultimate 4096-dim activations as the perceptual embedding
+            vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
+
+        self.vgg = vgg
+
+    def forward(self, pred_data, data, reduction = 'mean'):
+        embed = self.vgg(data)
+        pred_embed = self.vgg(pred_data)
+        loss = F.mse_loss(embed, pred_embed, reduction = reduction)
+
+        if reduction == 'none':
+            loss = reduce(loss, 'b ... -> b', 'mean')
+
+        return loss
+
+class PseudoHuberLoss(Module):
+    def __init__(self, data_dim: int):
+        super().__init__()
+        self.data_dim = data_dim
+
+    def forward(self, pred, target, reduction = 'mean', **kwargs):
+        c = .00054 * self.data_dim
+        loss = (F.mse_loss(pred, target, reduction = reduction) + c * c).sqrt() - c
+        return loss
+
+class PseudoHuberLossWithLPIPS(Module):
+    def __init__(self, data_dim: int, lpips_kwargs: dict = dict()):
+        super().__init__()
+        self.pseudo_huber = PseudoHuberLoss(data_dim)
+        self.lpips = LPIPSLoss(**lpips_kwargs)
+
+    def forward(self, pred_flow, target_flow, *, times, data):
+        huber_loss = self.pseudo_huber(pred_flow, target_flow, reduction = 'none')
+
+        pred_data = pred_flow * times
+        lpips_loss = self.lpips(data, pred_data, reduction = 'none')
+
+        time_weighted_loss = huber_loss * (1 - times) + lpips_loss * (1. / times.clamp(min = 1e-2))
+        return time_weighted_loss.mean()
+
+class MSELoss(Module):
+    def forward(self, pred, target, **kwargs):
+        return F.mse_loss(pred, target)
+
 # main class

 class RectifiedFlow(Module):
@@ -34,19 +95,33 @@ def __init__(
             rtol = 1e-5,
             method = 'midpoint'
         ),
-        loss_type: Literal[
-            'mse',
-            'pseudo_huber'
-        ] = 'mse',
+        loss_fn: Literal['mse', 'pseudo_huber', 'pseudo_huber_with_lpips'] | Module = 'mse',
+        loss_fn_kwargs: dict = dict(),
         data_shape: Tuple[int, ...] | None = None,
     ):
         super().__init__()
         self.model = model
         self.time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times

-        # loss type
+        # loss fn
+
+        if isinstance(loss_fn, Module):
+            loss_fn = loss_fn

-        self.loss_type = loss_type
+        elif loss_fn == 'mse':
+            loss_fn = MSELoss()
+
+        elif loss_fn == 'pseudo_huber':
+            # section 4.2 of https://arxiv.org/abs/2405.20320v1
+            loss_fn = PseudoHuberLoss(**loss_fn_kwargs)
+
+        elif loss_fn == 'pseudo_huber_with_lpips':
+            loss_fn = PseudoHuberLossWithLPIPS(**loss_fn_kwargs)
+
+        else:
+            raise ValueError(f'unknown loss function {loss_fn}')
+
+        self.loss_fn = loss_fn

         # sampling

@@ -135,17 +210,8 @@ def forward(
         pred_flow = self.model(noised, **model_kwargs)

         # loss
-        # section 4.2 of https://arxiv.org/abs/2405.20320v1
-
-        if self.loss_type == 'mse':
-            loss = F.mse_loss(pred_flow, flow)

-        elif self.loss_type == 'pseudo_huber':
-            c = .00054 * data_shape[0]
-            loss = (F.mse_loss(pred_flow, flow) + c ** 2).sqrt() - c
-
-        else:
-            raise ValueError(f'unrecognized loss type {self.loss_type}')
+        loss = self.loss_fn(pred_flow, flow, times = times, data = data)

         return loss

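For reference, a minimal sketch (not part of the commit) of exercising the new standalone loss modules on random tensors. The import path and tensor shapes are assumptions; `PseudoHuberLossWithLPIPS` is left out here because `LPIPSLoss` pulls down VGG16 weights on construction.

```python
# sketch only - the import path is an assumption about where this file lives in the package
import torch
from rectified_flow_pytorch.rectified_flow import MSELoss, PseudoHuberLoss

pred = torch.randn(4, 3, 32, 32)
target = torch.randn(4, 3, 32, 32)

# MSELoss simply ignores the extra keyword arguments the trainer passes along
mse = MSELoss()
print(mse(pred, target))

# data_dim sets the constant c = 0.00054 * data_dim used inside the pseudo-huber loss
huber = PseudoHuberLoss(data_dim = 3 * 32 * 32)
print(huber(pred, target))
```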
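And a hedged sketch of driving the refactored constructor. The toy velocity model, the top-level import, and the `rectified_flow(data)` training call are assumptions about the rest of the file; only the `loss_fn` / `loss_fn_kwargs` keyword arguments come from this commit.

```python
# sketch only - RectifiedFlow's import, forward(data) call and the toy model are assumptions;
# loss_fn / loss_fn_kwargs are the keyword arguments this commit introduces
import torch
from torch import nn
from rectified_flow_pytorch import RectifiedFlow

class ToyVelocityModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Conv2d(3, 3, 3, padding = 1)

    def forward(self, x, times = None):
        # accepts (and ignores) the time conditioning keyword
        return self.net(x)

rectified_flow = RectifiedFlow(
    ToyVelocityModel(),
    time_cond_kwarg = 'times',
    loss_fn = 'pseudo_huber',                      # or 'mse', 'pseudo_huber_with_lpips', or any nn.Module
    loss_fn_kwargs = dict(data_dim = 3 * 32 * 32)  # forwarded to the chosen loss class
)

data = torch.rand(4, 3, 32, 32)
loss = rectified_flow(data)   # forward() now dispatches to self.loss_fn
loss.backward()
```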