Skip to content

Commit 96c4dd9

Browse files
committed
add the weighted pseudo huber + lpips loss from a recent paper
1 parent 89584d7 commit 96c4dd9

File tree

2 files changed

+84
-17
lines changed

2 files changed

+84
-17
lines changed

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.4"
3+
version = "0.0.5"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }
@@ -26,6 +26,7 @@ dependencies = [
2626
'ema-pytorch>=0.5.1',
2727
'scipy',
2828
'torch>=2.0',
29+
'torchvision',
2930
'torchdiffeq',
3031
]
3132

rectified_flow_pytorch/rectified_flow.py

+82-16
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
from copy import deepcopy
44

55
import torch
6+
from torch import nn
67
from torch.nn import Module
78
import torch.nn.functional as F
89

910
from torchdiffeq import odeint
1011

12+
import torchvision
13+
from torchvision.models import VGG16_Weights
14+
15+
from einops import reduce, rearrange
16+
1117
# helpers
1218

1319
def exists(v):
@@ -22,6 +28,61 @@ def append_dims(t, ndims):
2228
shape = t.shape
2329
return t.reshape(*shape, *((1,) * ndims))
2430

31+
# losses
32+
33+
class LPIPSLoss(Module):
    """Perceptual loss: mean squared error between network embeddings of two batches.

    Args:
        vgg: optional embedding network. When omitted, a torchvision VGG16 is
            created with `vgg_weights` and its last two classifier layers are
            dropped.
        vgg_weights: weights handed to `torchvision.models.vgg16`; only used
            when `vgg` is not supplied.
    """

    def __init__(
        self,
        vgg: Module | None = None,
        vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
    ):
        super().__init__()

        # NOTE(review): the network is stored as a regular submodule, so its
        # parameters are exposed through `.parameters()` and its layers follow
        # this module's train / eval mode - confirm that is intended.
        self.vgg = vgg if exists(vgg) else self._default_vgg(vgg_weights)

    @staticmethod
    def _default_vgg(vgg_weights):
        # build VGG16 and strip the final two layers of its classifier head
        network = torchvision.models.vgg16(weights = vgg_weights)
        network.classifier = nn.Sequential(*network.classifier[:-2])
        return network

    def forward(self, pred_data, data, reduction = 'mean'):
        """Return the embedding-space MSE between `data` and `pred_data`.

        With `reduction = 'none'` the element-wise loss is averaged over every
        non-batch axis, giving one value per batch entry; any other reduction
        is whatever `F.mse_loss` returns for it.
        """
        # target batch is embedded first, then the prediction
        target_features, pred_features = (self.vgg(images) for images in (data, pred_data))

        distance = F.mse_loss(target_features, pred_features, reduction = reduction)

        if reduction != 'none':
            return distance

        return reduce(distance, 'b ... -> b', 'mean')
56+
57+
class PseudoHuberLoss(Module):
    """Pseudo-Huber loss, `sqrt(mse + c^2) - c` with `c = 0.00054 * data_dim`.

    See section 4.2 of https://arxiv.org/abs/2405.20320v1

    Args:
        data_dim: scales the constant `c`.
    """

    def __init__(self, data_dim: int):
        super().__init__()
        self.data_dim = data_dim

    def forward(self, pred, target, reduction = 'mean', **kwargs):
        """Compute the pseudo-Huber loss between `pred` and `target`.

        Args:
            pred: predicted tensor.
            target: tensor of the same shape as `pred`.
            reduction: forwarded to `F.mse_loss`. With `'none'` the result is
                one value per batch entry (mean over all non-batch axes);
                otherwise a scalar.
            **kwargs: accepted and ignored, so every loss can be called with
                the same keyword arguments.

        Returns:
            Scalar tensor, or a tensor of shape `(batch,)` for `reduction = 'none'`.
        """
        c = .00054 * self.data_dim
        loss = (F.mse_loss(pred, target, reduction = reduction) + c * c).sqrt() - c

        # an unreduced loss is still element-wise here; collapse it to one value
        # per sample so it can be weighted by per-sample quantities
        if reduction == 'none' and loss.ndim > 1:
            loss = loss.reshape(loss.shape[0], -1).mean(dim = -1)

        return loss
66+
67+
class PseudoHuberLossWithLPIPS(Module):
    """Time-weighted sum of a pseudo-Huber loss on the flow and an LPIPS loss on the data.

    Args:
        data_dim: forwarded to `PseudoHuberLoss`.
        lpips_kwargs: optional keyword arguments for `LPIPSLoss`.
    """

    def __init__(self, data_dim: int, lpips_kwargs: dict | None = None):
        super().__init__()
        self.pseudo_huber = PseudoHuberLoss(data_dim)
        # `None` replaces the shared mutable `dict()` default
        self.lpips = LPIPSLoss(**(lpips_kwargs or {}))

    def forward(self, pred_flow, target_flow, *, times, data):
        """Return the scalar time-weighted loss.

        The pseudo-Huber term is weighted by `1 - times` and the LPIPS term by
        `1 / times`, with `times` clamped to at least `1e-2`.

        Args:
            pred_flow: predicted flow, batch first.
            target_flow: target flow, same shape as `pred_flow`.
            times: one time value per batch entry (any shape that flattens to
                `(batch,)`, or a single value).
            data: data batch compared against the predicted data with LPIPS.
        """
        batch = pred_flow.shape[0]

        huber_loss = self.pseudo_huber(pred_flow, target_flow, reduction = 'none')

        # both terms must be one value per sample before they are combined
        if huber_loss.ndim > 1:
            huber_loss = huber_loss.reshape(batch, -1).mean(dim = -1)

        # flat times weight the per-sample losses; the padded view broadcasts
        # over the non-batch axes of the flow instead of its trailing axis
        times = times.reshape(-1)
        padded_times = times.reshape(-1, *((1,) * (pred_flow.ndim - 1)))

        # NOTE(review): `pred_flow * times` is kept from the original as the
        # data estimate - confirm against how the noised input is constructed
        pred_data = pred_flow * padded_times
        lpips_loss = self.lpips(pred_data, data, reduction = 'none')

        time_weighted_loss = huber_loss * (1 - times) + lpips_loss * (1. / times.clamp(min = 1e-2))
        return time_weighted_loss.mean()
81+
82+
class MSELoss(Module):
    """Mean squared error loss with the same call signature as the other losses."""

    def forward(self, pred, target, **kwargs):
        """Return the mean squared error of `pred` against `target`.

        Extra keyword arguments are accepted and ignored.
        """
        mean_squared_error = F.mse_loss(pred, target, reduction = 'mean')
        return mean_squared_error
85+
2586
# main class
2687

2788
class RectifiedFlow(Module):
@@ -34,19 +95,33 @@ def __init__(
3495
rtol = 1e-5,
3596
method = 'midpoint'
3697
),
37-
loss_type: Literal[
38-
'mse',
39-
'pseudo_huber'
40-
] = 'mse',
98+
loss_fn: Literal['mse', 'pseudo_huber'] | Module = 'mse',
99+
loss_fn_kwargs: dict = dict(),
41100
data_shape: Tuple[int, ...] | None = None,
42101
):
43102
super().__init__()
44103
self.model = model
45104
self.time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times
46105

47-
# loss type
106+
# loss fn
107+
108+
if isinstance(loss_fn, Module):
109+
loss_fn = loss_fn
48110

49-
self.loss_type = loss_type
111+
elif loss_fn == 'mse':
112+
loss_fn = MSELoss()
113+
114+
elif loss_fn == 'pseudo_huber':
115+
# section 4.2 of https://arxiv.org/abs/2405.20320v1
116+
loss_fn = PseudoHuberLoss(**loss_fn_kwargs)
117+
118+
elif loss_fn == 'pseudo_huber_with_lpips':
119+
loss_fn = PseudoHuberLossWithLPIPS(**loss_fn_kwargs)
120+
121+
else:
122+
raise ValueError(f'unkwown loss function {loss_fn}')
123+
124+
self.loss_fn = loss_fn
50125

51126
# sampling
52127

@@ -135,17 +210,8 @@ def forward(
135210
pred_flow = self.model(noised, **model_kwargs)
136211

137212
# loss
138-
# section 4.2 of https://arxiv.org/abs/2405.20320v1
139-
140-
if self.loss_type == 'mse':
141-
loss = F.mse_loss(pred_flow, flow)
142213

143-
elif self.loss_type == 'pseudo_huber':
144-
c = .00054 * data_shape[0]
145-
loss = (F.mse_loss(pred_flow, flow) + c ** 2).sqrt() - c
146-
147-
else:
148-
raise ValueError(f'unrecognized loss type {self.loss_type}')
214+
loss = self.loss_fn(pred_flow, flow, times = times, data = data)
149215

150216
return loss
151217

0 commit comments

Comments
 (0)