Skip to content

Commit c05f426

Browse files
committed
make sure one can do predict epsilon (noise) objective
1 parent ed7f2cd commit c05f426

File tree

2 files changed

+67
-24
lines changed

2 files changed

+67
-24
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.29"
3+
version = "0.1.0"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def forward(self, pred, target, **kwargs):
118118

119119
# loss breakdown
120120

121-
LossBreakdown = namedtuple('LossBreakdown', ['total', 'flow', 'data_match', 'velocity_match'])
121+
LossBreakdown = namedtuple('LossBreakdown', ['total', 'main', 'data_match', 'velocity_match'])
122122

123123
# main class
124124

@@ -132,6 +132,7 @@ def __init__(
132132
rtol = 1e-5,
133133
method = 'midpoint'
134134
),
135+
predict: Literal['flow', 'noise'] = 'flow',
135136
loss_fn: Literal[
136137
'mse',
137138
'pseudo_huber',
@@ -161,16 +162,24 @@ def __init__(
161162
self.model = model
162163
self.time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times
163164

165+
# objective - either flow or noise (proposed by Esser / Rombach et al in SD3)
166+
167+
self.predict = predict
168+
164169
# loss fn
165170

166171
if loss_fn == 'mse':
167172
loss_fn = MSELoss()
168173

169174
elif loss_fn == 'pseudo_huber':
175+
assert predict == 'flow'
176+
170177
# section 4.2 of https://arxiv.org/abs/2405.20320v1
171178
loss_fn = PseudoHuberLoss(**loss_fn_kwargs)
172179

173180
elif loss_fn == 'pseudo_huber_with_lpips':
181+
assert predict == 'flow'
182+
174183
loss_fn = PseudoHuberLossWithLPIPS(**loss_fn_kwargs)
175184

176185
elif not isinstance(loss_fn, Module):
@@ -223,6 +232,44 @@ def __init__(
223232
def device(self):
224233
return next(self.model.parameters()).device
225234

235+
def predict_flow(self, model: Module, noised, *, times, **model_kwargs):
    """
    returns the model output as well as the derived flow, depending on the `predict` objective

    model        - the network to evaluate (e.g. the main model or its ema copy); it is called as `model(noised, **kwargs)`
    noised       - the noised input, with batch as the first dimension
    times        - a tensor of times, either a single value (broadcast to the batch) or one value per batch element, in any shape
    model_kwargs - extra keyword arguments forwarded untouched to `model`

    returns a tuple of (raw model output, flow)

    raises ValueError if `self.predict` is neither 'flow' nor 'noise'
    """

    batch = noised.shape[0]

    # always flatten times to one value per batch element, so that the padding done for the
    # noise objective below starts from shape (batch,) even when the model is not time conditioned

    times = rearrange(times, '... -> (...)')

    if times.numel() == 1:
        times = repeat(times, '1 -> b', b = batch)

    # prepare maybe time conditioning for model

    model_kwargs = dict(model_kwargs)
    time_kwarg = self.time_cond_kwarg

    if exists(time_kwarg):
        model_kwargs.update(**{time_kwarg: times})

    # evaluate the model that was passed in (not unconditionally `self.model`), so the ema model can be used by callers

    output = model(noised, **model_kwargs)

    # depending on objective, derive flow

    if self.predict == 'flow':
        flow = output

    elif self.predict == 'noise':
        noise = output

        # NOTE(review): assumes `append_dims(t, n)` appends `n` trailing singleton dimensions - confirm against its definition
        padded_times = append_dims(times, noised.ndim - 1)

        # with noised = t * data + (1 - t) * noise and flow = data - noise, one has noised - noise = t * flow
        # times are clamped from below to avoid dividing by values near zero

        flow = (noised - noise) / padded_times.clamp(min = 1e-2)

    else:
        raise ValueError(f'unknown objective {self.predict}')

    return output, flow
272+
226273
@torch.no_grad()
227274
def sample(
228275
self,
@@ -245,13 +292,8 @@ def sample(
245292
assert exists(data_shape), 'you need to either pass in a `data_shape` or have trained at least with one forward'
246293

247294
def ode_fn(t, x):
248-
time_kwarg = self.time_cond_kwarg
249-
250-
if exists(time_kwarg):
251-
t = repeat(t, '-> b', b = x.shape[0])
252-
model_kwargs.update(**{time_kwarg: t})
253-
254-
return model(x, **model_kwargs)
295+
_, flow = self.predict_flow(model, x, times = t, **model_kwargs)
296+
return flow
255297

256298
# start with random gaussian noise - y0
257299

@@ -317,39 +359,40 @@ def get_noised_and_flows(model, t):
317359

318360
noised = t * data + (1. - t) * noise
319361

320-
# prepare maybe time conditioning for model
321-
322-
time_kwarg = self.time_cond_kwarg
323-
324-
if exists(time_kwarg):
325-
flat_time = rearrange(t, '... -> (...)')
326-
model_kwargs.update(**{time_kwarg: flat_time})
327-
328362
# the model predicts the flow from the noised data
329363

330364
flow = data - noise
331365

332-
pred_flow = model(noised, **model_kwargs)
366+
model_output, pred_flow = self.predict_flow(model, noised, times = t)
333367

334368
# predicted data will be the noised xt + flow * (1. - t)
335369

336370
pred_data = noised + pred_flow * (1. - t)
337371

338-
return flow, pred_flow, pred_data
372+
return model_output, flow, pred_flow, pred_data
339373

340374
# getting flow and pred flow for main model
341375

342-
flow, pred_flow, pred_data = get_noised_and_flows(self.model, padded_times)
376+
output, flow, pred_flow, pred_data = get_noised_and_flows(self.model, padded_times)
343377

344378
# if using consistency loss, also need the ema model predicted flow
345379

346380
if self.use_consistency:
347381
delta_t = self.consistency_delta_time
348-
ema_flow, ema_pred_flow, ema_pred_data = get_noised_and_flows(self.ema_model, padded_times + delta_t)
382+
ema_output, ema_flow, ema_pred_flow, ema_pred_data = get_noised_and_flows(self.ema_model, padded_times + delta_t)
383+
384+
# determine target, depending on objective
385+
386+
if self.predict == 'flow':
387+
target = flow
388+
elif self.predict == 'noise':
389+
target = noise
390+
else:
391+
raise ValueError(f'unknown objective {self.predict}')
349392

350393
# losses
351394

352-
main_flow_loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data)
395+
main_loss = self.loss_fn(output, target, pred_data = pred_data, times = times, data = data)
353396

354397
consistency_loss = data_match_loss = velocity_match_loss = 0.
355398

@@ -363,14 +406,14 @@ def get_noised_and_flows(model, t):
363406

364407
# total loss
365408

366-
total_loss = main_flow_loss + consistency_loss * self.consistency_loss_weight
409+
total_loss = main_loss + consistency_loss * self.consistency_loss_weight
367410

368411
if not return_loss_breakdown:
369412
return total_loss
370413

371414
# loss breakdown
372415

373-
return total_loss, LossBreakdown(total_loss, main_flow_loss, data_match_loss, velocity_match_loss)
416+
return total_loss, LossBreakdown(total_loss, main_loss, data_match_loss, velocity_match_loss)
374417

375418
# reflow wrapper
376419

0 commit comments

Comments
 (0)