Skip to content

Commit dc6ba3d

Browse files
committed
Add a new direction loss for the velocities, proposed by a group of researchers out of Wuhan, China
1 parent 598d08e commit dc6ba3d

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,12 @@ trainer()
142142
url = {https://api.semanticscholar.org/CorpusID:270878436}
143143
}
144144
```
145+
146+
```bibtex
147+
@inproceedings{Yao2024FasterDiTTF,
148+
title = {FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification},
149+
author = {Jingfeng Yao and Wang Cheng and Wenyu Liu and Xinggang Wang},
150+
year = {2024},
151+
url = {https://api.semanticscholar.org/CorpusID:273346237}
152+
}
153+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.1.10"
3+
version = "0.1.11"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ class MSELoss(Module):
116116
def forward(self, pred, target, **kwargs):
117117
return F.mse_loss(pred, target)
118118

119+
class MSEAndDirectionLoss(Module):
    """
    Mean squared error plus a cosine-similarity based direction penalty.

    Figure 7 - https://arxiv.org/abs/2410.10356

    The direction term is `1 - cosine_similarity(pred, target)`, computed along
    `cosine_sim_dim` and averaged over all remaining positions. It is scaled by
    `direction_loss_weight` and added to `F.mse_loss(pred, target)`.

    Args:
        cosine_sim_dim: dimension along which the cosine similarity is taken.
            Must be strictly positive, as dimension 0 is treated as the batch
            dimension.
        direction_loss_weight: multiplier applied to the direction term before
            it is added to the MSE term. The default of 1. yields the plain sum
            of both terms.
    """

    def __init__(
        self,
        cosine_sim_dim: int = 1,
        direction_loss_weight: float = 1.
    ):
        super().__init__()
        assert cosine_sim_dim > 0, 'cannot be batch dimension'
        self.cosine_sim_dim = cosine_sim_dim
        self.direction_loss_weight = direction_loss_weight

    def forward(self, pred, target, **kwargs):
        """
        Return the scalar `mse + direction_loss_weight * direction` loss.

        Extra keyword arguments are accepted and ignored, mirroring the
        signature of the other loss modules in this file.
        """
        mse_loss = F.mse_loss(pred, target)

        # 1 - cos(pred, target) is zero when the two point in the same direction
        # and grows up to 2 when they point in opposite directions

        cosine_sim = F.cosine_similarity(pred, target, dim = self.cosine_sim_dim)
        direction_loss = (1. - cosine_sim).mean()

        return mse_loss + direction_loss * self.direction_loss_weight
135+
119136
# loss breakdown
120137

121138
LossBreakdown = namedtuple('LossBreakdown', ['total', 'main', 'data_match', 'velocity_match'])
@@ -135,6 +152,7 @@ def __init__(
135152
predict: Literal['flow', 'noise'] = 'flow',
136153
loss_fn: Literal[
137154
'mse',
155+
'mse_and_direction',
138156
'pseudo_huber',
139157
'pseudo_huber_with_lpips'
140158
] | Module = 'mse',
@@ -179,6 +197,9 @@ def __init__(
179197
if loss_fn == 'mse':
180198
loss_fn = MSELoss()
181199

200+
elif loss_fn == 'mse_and_direction':
201+
loss_fn = MSEAndDirectionLoss(**loss_fn_kwargs)
202+
182203
elif loss_fn == 'pseudo_huber':
183204
assert predict == 'flow'
184205

0 commit comments

Comments
 (0)