@@ -118,7 +118,7 @@ def forward(self, pred, target, **kwargs):

 # loss breakdown

-LossBreakdown = namedtuple('LossBreakdown', ['total', 'flow', 'data_match', 'velocity_match'])
+LossBreakdown = namedtuple('LossBreakdown', ['total', 'main', 'data_match', 'velocity_match'])

 # main class

@@ -132,6 +132,7 @@ def __init__(
             rtol = 1e-5,
             method = 'midpoint'
         ),
+        predict: Literal['flow', 'noise'] = 'flow',
         loss_fn: Literal[
             'mse',
             'pseudo_huber',
@@ -161,16 +162,24 @@ def __init__(
         self.model = model
         self.time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times

+        # objective - either flow or noise (proposed by Esser / Rombach et al in SD3)
+
+        self.predict = predict
+
         # loss fn

         if loss_fn == 'mse':
             loss_fn = MSELoss()

         elif loss_fn == 'pseudo_huber':
+            assert predict == 'flow'
+
             # section 4.2 of https://arxiv.org/abs/2405.20320v1
             loss_fn = PseudoHuberLoss(**loss_fn_kwargs)

         elif loss_fn == 'pseudo_huber_with_lpips':
+            assert predict == 'flow'
+
             loss_fn = PseudoHuberLossWithLPIPS(**loss_fn_kwargs)

         elif not isinstance(loss_fn, Module):
@@ -223,6 +232,44 @@ def __init__(
     def device(self):
         return next(self.model.parameters()).device

+    def predict_flow(self, model: Module, noised, *, times):
+        """
+        returns the model output as well as the derived flow, depending on the `predict` objective
+        """
+
+        batch = noised.shape[0]
+
+        # prepare maybe time conditioning for model
+
+        model_kwargs = dict()
+        time_kwarg = self.time_cond_kwarg
+
+        if exists(time_kwarg):
+            times = rearrange(times, '... -> (...)')
+
+            if times.numel() == 1:
+                times = repeat(times, '1 -> b', b = batch)
+
+            model_kwargs.update(**{time_kwarg: times})
+
+        output = model(noised, **model_kwargs)
+
+        # depending on objective, derive flow
+
+        if self.predict == 'flow':
+            flow = output
+
+        elif self.predict == 'noise':
+            noise = output
+            padded_times = append_dims(times, noised.ndim - 1)
+
+            flow = (noised - noise) / padded_times.clamp(min = 1e-2)
+
+        else:
+            raise ValueError(f'unknown objective {self.predict}')
+
+        return output, flow
+
     @torch.no_grad()
     def sample(
         self,
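A quick sanity check of the identity the 'noise' branch above relies on (illustrative only, not part of the commit): with the linear interpolation noised = t * data + (1 - t) * noise used later in the forward pass, dividing (noised - noise) by t recovers the flow data - noise.

import torch

data, noise = torch.randn(2, 4, 32, 32), torch.randn(2, 4, 32, 32)
t = torch.rand(2, 1, 1, 1).clamp(min = 1e-2)   # times padded to the data dims, clamped as above

noised = t * data + (1. - t) * noise           # same interpolation as in forward

flow = data - noise                            # the 'flow' target
recovered_flow = (noised - noise) / t          # what the 'noise' branch computes

assert torch.allclose(flow, recovered_flow, atol = 1e-5)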
@@ -245,13 +292,8 @@ def sample(
         assert exists(data_shape), 'you need to either pass in a `data_shape` or have trained at least with one forward'

         def ode_fn(t, x):
-            time_kwarg = self.time_cond_kwarg
-
-            if exists(time_kwarg):
-                t = repeat(t, '-> b', b = x.shape[0])
-                model_kwargs.update(**{time_kwarg: t})
-
-            return model(x, **model_kwargs)
+            _, flow = self.predict_flow(model, x, times = t, **model_kwargs)
+            return flow

         # start with random gaussian noise - y0

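With ode_fn now returning the flow for either objective, sampling is just integrating dx/dt = flow(x, t) from noise at t = 0 toward data at t = 1. The code above hands ode_fn to an ODE solver (method = 'midpoint' per the odeint_kwargs default); the toy Euler loop below, with a made-up flow function standing in for the model, only illustrates that idea and is not part of the commit.

import torch

def toy_flow_fn(t, x):
    # made-up stand-in for ode_fn: a flow field pointing toward data = 1
    return torch.ones_like(x) - x

x = torch.randn(2, 3)                          # y0: start from gaussian noise
times = torch.linspace(0., 1., 17)             # 16 euler steps from t = 0 to t = 1

for t0, t1 in zip(times[:-1], times[1:]):
    x = x + (t1 - t0) * toy_flow_fn(t0, x)     # euler step along the flow

print(x)                                       # roughly integrated toward the toy data point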
@@ -317,39 +359,40 @@ def get_noised_and_flows(model, t):

             noised = t * data + (1. - t) * noise

-            # prepare maybe time conditioning for model
-
-            time_kwarg = self.time_cond_kwarg
-
-            if exists(time_kwarg):
-                flat_time = rearrange(t, '... -> (...)')
-                model_kwargs.update(**{time_kwarg: flat_time})
-
             # the model predicts the flow from the noised data

             flow = data - noise

-            pred_flow = model(noised, **model_kwargs)
+            model_output, pred_flow = self.predict_flow(model, noised, times = t)

             # predicted data will be the noised xt + flow * (1. - t)

             pred_data = noised + pred_flow * (1. - t)

-            return flow, pred_flow, pred_data
+            return model_output, flow, pred_flow, pred_data

         # getting flow and pred flow for main model

-        flow, pred_flow, pred_data = get_noised_and_flows(self.model, padded_times)
+        output, flow, pred_flow, pred_data = get_noised_and_flows(self.model, padded_times)

         # if using consistency loss, also need the ema model predicted flow

         if self.use_consistency:
             delta_t = self.consistency_delta_time
-            ema_flow, ema_pred_flow, ema_pred_data = get_noised_and_flows(self.ema_model, padded_times + delta_t)
+            ema_output, ema_flow, ema_pred_flow, ema_pred_data = get_noised_and_flows(self.ema_model, padded_times + delta_t)
+
+        # determine target, depending on objective
+
+        if self.predict == 'flow':
+            target = flow
+        elif self.predict == 'noise':
+            target = noise
+        else:
+            raise ValueError(f'unknown objective {self.predict}')

         # losses

-        main_flow_loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data)
+        main_loss = self.loss_fn(output, target, pred_data = pred_data, times = times, data = data)

         consistency_loss = data_match_loss = velocity_match_loss = 0.

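One more illustrative check (not part of the commit) of why the data-space loss kwargs keep working under either objective: pred_data = noised + pred_flow * (1. - t) reconstructs the original data exactly whenever the predicted flow matches data - noise.

import torch

data, noise = torch.randn(8, 16), torch.randn(8, 16)
t = torch.rand(8, 1)

noised = t * data + (1. - t) * noise
pred_flow = data - noise                       # pretend the model predicts the flow exactly

pred_data = noised + pred_flow * (1. - t)      # same reconstruction as in get_noised_and_flows

assert torch.allclose(pred_data, data, atol = 1e-5)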
@@ -363,14 +406,14 @@ def get_noised_and_flows(model, t):

         # total loss

-        total_loss = main_flow_loss + consistency_loss * self.consistency_loss_weight
+        total_loss = main_loss + consistency_loss * self.consistency_loss_weight

         if not return_loss_breakdown:
             return total_loss

         # loss breakdown

-        return total_loss, LossBreakdown(total_loss, main_flow_loss, data_match_loss, velocity_match_loss)
+        return total_loss, LossBreakdown(total_loss, main_loss, data_match_loss, velocity_match_loss)

 # reflow wrapper
