Commit 2598074

make sure to handle all scenarios for sampling depending on whether consistency FM is turned on

1 parent 92f73ff commit 2598074
File tree

2 files changed (+32, -16 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "rectified-flow-pytorch"
-version = "0.0.23"
+version = "0.0.24"
 description = "Rectified Flow in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 31 additions & 15 deletions
@@ -219,8 +219,14 @@ def sample(
         steps = 16,
         noise = None,
         data_shape: Tuple[int, ...] | None = None,
+        use_ema: bool = False,
         **model_kwargs
     ):
+        use_ema = default(use_ema, self.use_consistency)
+        assert not (use_ema and not self.use_consistency), 'in order to sample from an ema model, you must have `use_consistency` turned on'
+
+        model = self.ema_model if use_ema else self.model
+
         was_training = self.training
         self.eval()

@@ -234,7 +240,7 @@ def ode_fn(t, x):
             t = repeat(t, '-> b', b = x.shape[0])
             model_kwargs.update(**{time_kwarg: t})
 
-            return self.model(x, **model_kwargs)
+            return model(x, **model_kwargs)
 
         # start with random gaussian noise - y0
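Taken together with the new `use_ema` argument above, `sample()` now dispatches to either the online network or the consistency-FM EMA. A rough sketch of both call patterns (the `Unet` construction and shapes below are illustrative, not part of this commit):

    # illustrative only: assumes the package's Unet and RectifiedFlow wrappers
    from rectified_flow_pytorch import RectifiedFlow, Unet

    model = Unet(dim = 64)

    # consistency FM on: the flow keeps its own EMA, which `use_ema = True` samples from
    flow = RectifiedFlow(model, use_consistency = True)
    samples = flow.sample(steps = 16, data_shape = (3, 64, 64), use_ema = True)

    # consistency FM off: sampling uses the online model; passing `use_ema = True` here
    # would trip the new assert, since no consistency EMA exists on the flow
    flow = RectifiedFlow(model)
    samples = flow.sample(steps = 16, data_shape = (3, 64, 64))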

@@ -825,34 +831,41 @@ def __init__(
         adam_kwargs: dict = dict(),
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
-        use_consistency_ema = False # whether to just use the EMA from the velocity consistency from the consistency FM paper
+        use_ema = True
     ):
         super().__init__()
         self.accelerator = Accelerator(**accelerate_kwargs)
 
         self.model = rectified_flow
 
-        if self.is_main:
-            if use_consistency_ema:
-                assert self.model.use_consistency, 'model must be using the consistency EMA for it to be reused as the main EMA model during sampling'
+        # determine whether to keep track of EMA (if not using consistency FM)
+        # which will determine which model to use for sampling
 
-                self.ema_model = self.model.ema_model
-            else:
-                self.ema_model = EMA(
-                    self.model,
-                    forward_method_names = ('sample',),
-                    **ema_kwargs
-                )
+        use_ema &= not self.model.use_consistency
+
+        self.use_ema = use_ema
+        self.ema_model = None
+
+        if self.is_main and use_ema:
+            self.ema_model = EMA(
+                self.model,
+                forward_method_names = ('sample',),
+                **ema_kwargs
+            )
 
             self.ema_model.to(self.accelerator.device)
 
+        # optimizer, dataloader, and all that
+
         self.optimizer = Adam(rectified_flow.parameters(), lr = learning_rate, **adam_kwargs)
         self.dl = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)
 
         self.model, self.optimizer, self.dl = self.accelerator.prepare(self.model, self.optimizer, self.dl)
 
         self.num_train_steps = num_train_steps
 
+        # folders
+
         self.checkpoints_folder = Path(checkpoints_folder)
         self.results_folder = Path(results_folder)
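The operative line is `use_ema &= not self.model.use_consistency`: the trainer only wraps the model in its own EMA when the flow is not already maintaining a consistency-FM EMA. A tiny standalone sketch of that decision (the function name is made up for illustration):

    def trainer_tracks_ema(use_ema: bool, flow_uses_consistency: bool) -> bool:
        # mirrors `use_ema &= not self.model.use_consistency` in the constructor above
        return use_ema and not flow_uses_consistency

    assert trainer_tracks_ema(use_ema = True, flow_uses_consistency = False)       # plain flow matching: trainer builds EMA(self.model, ...)
    assert not trainer_tracks_ema(use_ema = True, flow_uses_consistency = True)    # consistency FM: the flow's own ema_model is used instead
    assert not trainer_tracks_ema(use_ema = False, flow_uses_consistency = False)  # EMA explicitly disabled: sample from the online model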

@@ -906,17 +919,20 @@ def forward(self):
             if self.model.use_consistency:
                 self.model.ema_model.update()
 
-            if self.is_main:
+            if self.is_main and self.use_ema:
+                self.ema_model.ema_model.data_shape = self.model.data_shape
+
                 self.ema_model.update()
 
             self.accelerator.wait_for_everyone()
 
             if self.is_main:
+                eval_model = default(self.ema_model, self.model)
+
                 if divisible_by(step, self.save_results_every):
-                    self.ema_model.ema_model.data_shape = self.model.data_shape
 
                     with torch.no_grad():
-                        sampled = self.ema_model.sample(batch_size = self.num_samples)
+                        sampled = eval_model.sample(batch_size = self.num_samples)
 
                     sampled.clamp_(0., 1.)
                     save_image(sampled, str(self.results_folder / f'results.{step}.png'), nrow = self.num_sample_rows)
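During periodic sampling, `eval_model = default(self.ema_model, self.model)` picks the trainer-level EMA when one was created and otherwise falls back to the prepared online model, which covers both the consistency-FM case and `use_ema = False`. Assuming `default` is the usual first-non-None helper used elsewhere in the file, the dispatch reduces to:

    def default(val, d):
        # assumed definition of the helper: return the value when it is not None, else the fallback
        return val if val is not None else d

    model = object()        # stand-in for the accelerate-prepared rectified flow
    ema_model = None        # consistency FM on, or use_ema = False: trainer kept no EMA wrapper
    assert default(ema_model, model) is model

    ema_model = object()    # plain flow matching with use_ema = True: sample from the EMA wrapper
    assert default(ema_model, model) is ema_model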
