|
4 | 4 | from typing import Tuple, List, Literal, Callable
|
5 | 5 |
|
6 | 6 | import torch
|
| 7 | +from torch import Tensor |
7 | 8 | from torch import nn, pi, from_numpy
|
8 | 9 | from torch.nn import Module, ModuleList
|
9 | 10 | import torch.nn.functional as F
|
@@ -263,7 +264,8 @@ def ode_fn(t, x):
|
263 | 264 | def forward(
|
264 | 265 | self,
|
265 | 266 | data,
|
266 |
| - noise = None, |
| 267 | + noise: Tensor | None = None, |
| 268 | + return_loss_breakdown = False, |
267 | 269 | **model_kwargs
|
268 | 270 | ):
|
269 | 271 | batch, *data_shape = data.shape
|
@@ -337,21 +339,28 @@ def get_noised_and_flows(model, t):
|
337 | 339 |
|
338 | 340 | # losses
|
339 | 341 |
|
340 |
| - loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data) |
| 342 | + main_flow_loss = self.loss_fn(pred_flow, flow, pred_data = pred_data, times = times, data = data) |
| 343 | + |
| 344 | + consistency_loss = data_match_loss = velocity_match_loss = 0. |
341 | 345 |
|
342 | 346 | if self.use_consistency:
|
343 |
| - # add velocity consistency loss from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1 |
| 347 | + # consistency losses from consistency fm paper - eq (6) in https://arxiv.org/html/2407.02398v1 |
344 | 348 |
|
345 |
| - α = self.consistency_velocity_match_alpha |
| 349 | + data_match_loss = F.mse_loss(pred_data, ema_pred_data) |
| 350 | + velocity_match_loss = F.mse_loss(pred_flow, ema_pred_flow) |
346 | 351 |
|
347 |
| - consistency_loss = ( |
348 |
| - F.mse_loss(pred_data, ema_pred_data) + |
349 |
| - α * F.mse_loss(pred_flow, ema_pred_flow) |
350 |
| - ) |
| 352 | + consistency_loss = data_match_loss + velocity_match_loss * self.consistency_velocity_match_alpha |
351 | 353 |
|
352 |
| - loss = loss + consistency_loss * self.consistency_loss_weight |
| 354 | + # total loss |
353 | 355 |
|
354 |
| - return loss |
| 356 | + total_loss = main_flow_loss + consistency_loss * self.consistency_loss_weight |
| 357 | + |
| 358 | + if not return_loss_breakdown: |
| 359 | + return total_loss |
| 360 | + |
| 361 | + # loss breakdown |
| 362 | + |
| 363 | + return loss, (main_flow_loss, data_match_loss, velocity_match_loss) |
355 | 364 |
|
356 | 365 | # reflow wrapper
|
357 | 366 |
|
|
0 commit comments