from copy import deepcopy

import torch
+ from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from torchdiffeq import odeint

+ import torchvision
+ from torchvision.models import VGG16_Weights
+
+ from einops import reduce, rearrange
+
# helpers

def exists(v):
@@ -22,6 +28,61 @@ def append_dims(t, ndims):
    shape = t.shape
    return t.reshape(*shape, *((1,) * ndims))

+ # losses
+
+ class LPIPSLoss(Module):
+     def __init__(
+         self,
+         vgg: Module | None = None,
+         vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
+     ):
+         super().__init__()
+
+         if not exists(vgg):
+             vgg = torchvision.models.vgg16(weights = vgg_weights)
+             vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
+
+         self.vgg = vgg
+
+     def forward(self, pred_data, data, reduction = 'mean'):
+         embed = self.vgg(data)
+         pred_embed = self.vgg(pred_data)
+         loss = F.mse_loss(embed, pred_embed, reduction = reduction)
+
+         if reduction == 'none':
+             loss = reduce(loss, 'b ... -> b', 'mean')
+
+         return loss
+
+ class PseudoHuberLoss(Module):
+     def __init__(self, data_dim: int):
+         super().__init__()
+         self.data_dim = data_dim
+
+     def forward(self, pred, target, reduction = 'mean', **kwargs):
+         c = .00054 * self.data_dim
+         loss = (F.mse_loss(pred, target, reduction = reduction) + c * c).sqrt() - c
+         return loss
+
+ class PseudoHuberLossWithLPIPS(Module):
+     def __init__(self, data_dim: int, lpips_kwargs: dict = dict()):
+         super().__init__()
+         self.pseudo_huber = PseudoHuberLoss(data_dim)
+         self.lpips = LPIPSLoss(**lpips_kwargs)
+
+     def forward(self, pred_flow, target_flow, *, times, data):
+         huber_loss = self.pseudo_huber(pred_flow, target_flow, reduction = 'none')
+
+         pred_data = pred_flow * times
+         lpips_loss = self.lpips(data, pred_data, reduction = 'none')
+
+         time_weighted_loss = huber_loss * (1 - times) + lpips_loss * (1. / times.clamp(min = 1e-2))
+         return time_weighted_loss.mean()
+
+ class MSELoss(Module):
+     def forward(self, pred, target, **kwargs):
+         return F.mse_loss(pred, target)
+
# main class

class RectifiedFlow(Module):
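
Aside, not part of the commit: the loss classes added above can be sanity-checked on their own. A minimal sketch, assuming this file is importable as-is; the tensor shapes and the data_dim value are illustrative, and LPIPSLoss is left out because constructing it downloads VGG16 weights.

import torch

pred = torch.randn(4, 3, 16, 16)
target = torch.randn(4, 3, 16, 16)

# pseudo-huber: roughly quadratic for small errors, closer to linear for large ones
huber = PseudoHuberLoss(data_dim = 3)   # data_dim only scales the constant c
print(huber(pred, target))              # scalar tensor

# plain mean squared error, matching the old 'mse' branch
mse = MSELoss()
print(mse(pred, target))
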
@@ -34,19 +95,33 @@ def __init__(
            rtol = 1e-5,
            method = 'midpoint'
        ),
-         loss_type: Literal[
-             'mse',
-             'pseudo_huber'
-         ] = 'mse',
+         loss_fn: Literal['mse', 'pseudo_huber', 'pseudo_huber_with_lpips'] | Module = 'mse',
+         loss_fn_kwargs: dict = dict(),
        data_shape: Tuple[int, ...] | None = None,
    ):
        super().__init__()
        self.model = model
        self.time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times

-         # loss type
+         # loss fn
+
+         if isinstance(loss_fn, Module):
+             loss_fn = loss_fn

-         self.loss_type = loss_type
+         elif loss_fn == 'mse':
+             loss_fn = MSELoss()
+
+         elif loss_fn == 'pseudo_huber':
+             # section 4.2 of https://arxiv.org/abs/2405.20320v1
+             loss_fn = PseudoHuberLoss(**loss_fn_kwargs)
+
+         elif loss_fn == 'pseudo_huber_with_lpips':
+             loss_fn = PseudoHuberLossWithLPIPS(**loss_fn_kwargs)
+
+         else:
+             raise ValueError(f'unknown loss function {loss_fn}')
+
+         self.loss_fn = loss_fn

        # sampling

@@ -135,17 +210,8 @@ def forward(
        pred_flow = self.model(noised, **model_kwargs)

        # loss
-         # section 4.2 of https://arxiv.org/abs/2405.20320v1
-
-         if self.loss_type == 'mse':
-             loss = F.mse_loss(pred_flow, flow)

-         elif self.loss_type == 'pseudo_huber':
-             c = .00054 * data_shape[0]
-             loss = (F.mse_loss(pred_flow, flow) + c ** 2).sqrt() - c
-
-         else:
-             raise ValueError(f'unrecognized loss type {self.loss_type}')
+         loss = self.loss_fn(pred_flow, flow, times = times, data = data)

        return loss
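
A minimal usage sketch of the new loss_fn interface (also not part of the commit). It assumes RectifiedFlow's forward takes a batch of data and returns the scalar training loss, as the last hunk suggests, and that the model is conditioned on times through the default time_cond_kwarg; the tiny convolutional model and all shapes below are illustrative stand-ins.

import torch
from torch import nn

class TinyFlowModel(nn.Module):
    # toy stand-in for the flow / velocity model; accepts and ignores the `times` kwarg
    def __init__(self, dim = 3):
        super().__init__()
        self.net = nn.Conv2d(dim, dim, 3, padding = 1)

    def forward(self, x, times = None):
        return self.net(x)

model = TinyFlowModel()

# string shortcut: extra kwargs are forwarded to the chosen loss class
rectified_flow = RectifiedFlow(
    model,
    loss_fn = 'pseudo_huber',
    loss_fn_kwargs = dict(data_dim = 3)   # illustrative value
)

# or hand in any Module whose forward accepts (pred, target, **kwargs)
rectified_flow = RectifiedFlow(model, loss_fn = MSELoss())

data = torch.randn(2, 3, 16, 16)
loss = rectified_flow(data)
loss.backward()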