@@ -672,7 +672,7 @@ def __init__(
         if isinstance(folder, str):
             folder = Path(folder)

-        assert folder.exists() and folder.is_dir()
+        assert folder.is_dir()

         self.folder = folder
         self.image_size = image_size
@@ -707,21 +707,62 @@ def __getitem__(self, index):

 from torch.optim import Adam
 from accelerate import Accelerator
+from torch.utils.data import DataLoader

-class Trainer:
+def cycle(dl):
+    while True:
+        for batch in dl:
+            yield batch
+
+class Trainer(Module):
     def __init__(
         self,
         rectified_flow: RectifiedFlow,
         *,
         dataset: Dataset,
         num_train_steps = 70_000,
-        learning_rate: 3e-4,
+        learning_rate = 3e-4,
+        batch_size = 16,
         adam_kwargs: dict = dict(),
         accelerate_kwargs: dict = dict(),
         checkpoints_folder: str = './checkpoints',
         results_folder: str = './results'
     ):
-        return self
+        super().__init__()
+        self.accelerator = Accelerator(**accelerate_kwargs)
+
+        self.model = rectified_flow
+        self.optimizer = Adam(rectified_flow.parameters(), lr = learning_rate, **adam_kwargs)
+        self.dl = DataLoader(dataset, batch_size = batch_size)
+
+        self.model, self.optimizer, self.dl = self.accelerator.prepare(self.model, self.optimizer, self.dl)
+
+        self.num_train_steps = num_train_steps
+
+        self.checkpoints_folder = Path(checkpoints_folder)
+        self.results_folder = Path(results_folder)
+
+        self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
+        self.results_folder.mkdir(exist_ok = True, parents = True)
+
+        assert self.checkpoints_folder.is_dir()
+        assert self.results_folder.is_dir()
+
+    def forward(self):
+
+        dl = cycle(self.dl)
+
+        for _ in range(self.num_train_steps):
+            self.model.train()
+
+            data = next(dl)
+            loss = self.model(data)
+
+            self.accelerator.print(f'loss: {loss.item():.3f}')
+            self.accelerator.backward(loss)
+
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+        print('training complete')

-    def __call__(self):
-        return self
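For orientation, below is a minimal usage sketch of the `Trainer` added in this commit. The `cycle` helper turns the `DataLoader` into an infinite generator, so training runs for a fixed number of optimizer steps rather than whole epochs, and since `Trainer` now subclasses `Module`, calling the instance dispatches to `forward`. The toy dataset and hyperparameter values here are illustrative assumptions, not part of the commit, and `rectified_flow` is presumed to be a `RectifiedFlow` instance built with the file's existing API:

```python
# Illustrative sketch only; the toy dataset and hyperparameter values are
# assumptions. Trainer comes from the file patched above, and rectified_flow
# is assumed to be a RectifiedFlow instance constructed via the existing API.
import torch
from torch.utils.data import Dataset

class RandomImages(Dataset):
    # stand-in dataset yielding random 3 x 64 x 64 tensors in place of images
    def __len__(self):
        return 128

    def __getitem__(self, index):
        return torch.randn(3, 64, 64)

trainer = Trainer(
    rectified_flow,             # RectifiedFlow instance, constructed elsewhere
    dataset = RandomImages(),
    num_train_steps = 100,      # kept short for a smoke test
    batch_size = 4
)

trainer()  # Module.__call__ dispatches to forward(), running the training loop
```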