@@ -219,8 +219,14 @@ def sample(
219
219
steps = 16 ,
220
220
noise = None ,
221
221
data_shape : Tuple [int , ...] | None = None ,
222
+ use_ema : bool = False ,
222
223
** model_kwargs
223
224
):
225
+ use_ema = default (use_ema , self .use_consistency )
226
+ assert not (use_ema and not self .use_consistency ), 'in order to sample from an ema model, you must have `use_consistency` turned on'
227
+
228
+ model = self .ema_model if use_ema else self .model
229
+
224
230
was_training = self .training
225
231
self .eval ()
226
232
@@ -234,7 +240,7 @@ def ode_fn(t, x):
234
240
t = repeat (t , '-> b' , b = x .shape [0 ])
235
241
model_kwargs .update (** {time_kwarg : t })
236
242
237
- return self . model (x , ** model_kwargs )
243
+ return model (x , ** model_kwargs )
238
244
239
245
# start with random gaussian noise - y0
240
246
@@ -825,34 +831,41 @@ def __init__(
825
831
adam_kwargs : dict = dict (),
826
832
accelerate_kwargs : dict = dict (),
827
833
ema_kwargs : dict = dict (),
828
- use_consistency_ema = False # whether to just use the EMA from the velocity consistency from the consistency FM paper
834
+ use_ema = True
829
835
):
830
836
super ().__init__ ()
831
837
self .accelerator = Accelerator (** accelerate_kwargs )
832
838
833
839
self .model = rectified_flow
834
840
835
- if self .is_main :
836
- if use_consistency_ema :
837
- assert self .model .use_consistency , 'model must be using the consistency EMA for it to be reused as the main EMA model during sampling'
841
+ # determine whether to keep track of EMA (if not using consistency FM)
842
+ # which will determine which model to use for sampling
838
843
839
- self .ema_model = self .model .ema_model
840
- else :
841
- self .ema_model = EMA (
842
- self .model ,
843
- forward_method_names = ('sample' ,),
844
- ** ema_kwargs
845
- )
844
+ use_ema &= not self .model .use_consistency
845
+
846
+ self .use_ema = use_ema
847
+ self .ema_model = None
848
+
849
+ if self .is_main and use_ema :
850
+ self .ema_model = EMA (
851
+ self .model ,
852
+ forward_method_names = ('sample' ,),
853
+ ** ema_kwargs
854
+ )
846
855
847
856
self .ema_model .to (self .accelerator .device )
848
857
858
+ # optimizer, dataloader, and all that
859
+
849
860
self .optimizer = Adam (rectified_flow .parameters (), lr = learning_rate , ** adam_kwargs )
850
861
self .dl = DataLoader (dataset , batch_size = batch_size , shuffle = True , drop_last = True )
851
862
852
863
self .model , self .optimizer , self .dl = self .accelerator .prepare (self .model , self .optimizer , self .dl )
853
864
854
865
self .num_train_steps = num_train_steps
855
866
867
+ # folders
868
+
856
869
self .checkpoints_folder = Path (checkpoints_folder )
857
870
self .results_folder = Path (results_folder )
858
871
@@ -906,17 +919,20 @@ def forward(self):
906
919
if self .model .use_consistency :
907
920
self .model .ema_model .update ()
908
921
909
- if self .is_main :
922
+ if self .is_main and self .use_ema :
923
+ self .ema_model .ema_model .data_shape = self .model .data_shape
924
+
910
925
self .ema_model .update ()
911
926
912
927
self .accelerator .wait_for_everyone ()
913
928
914
929
if self .is_main :
930
+ eval_model = default (self .ema_model , self .model )
931
+
915
932
if divisible_by (step , self .save_results_every ):
916
- self .ema_model .ema_model .data_shape = self .model .data_shape
917
933
918
934
with torch .no_grad ():
919
- sampled = self . ema_model .sample (batch_size = self .num_samples )
935
+ sampled = eval_model .sample (batch_size = self .num_samples )
920
936
921
937
sampled .clamp_ (0. , 1. )
922
938
save_image (sampled , str (self .results_folder / f'results.{ step } .png' ), nrow = self .num_sample_rows )
0 commit comments