Skip to content

Commit 8fc6d8c

Browse files
committed
loss goes down
1 parent 81ad431 commit 8fc6d8c

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.8"
3+
version = "0.0.9"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def __init__(
672672
if isinstance(folder, str):
673673
folder = Path(folder)
674674

675-
assert folder.exists() and folder.is_dir()
675+
assert folder.is_dir()
676676

677677
self.folder = folder
678678
self.image_size = image_size
@@ -707,21 +707,62 @@ def __getitem__(self, index):
707707

708708
from torch.optim import Adam
709709
from accelerate import Accelerator
710+
from torch.utils.data import DataLoader
710711

711-
class Trainer:
def cycle(dl):
    """Yield batches from `dl` forever, restarting it each time it is exhausted."""
    while True:
        yield from dl
class Trainer(Module):
    """Minimal training harness for a `RectifiedFlow` model.

    Wraps the model, an Adam optimizer, and a `DataLoader` over `dataset`
    with HuggingFace Accelerate, then runs `num_train_steps` optimization
    steps when the instance is called (via `forward`).

    Note: `checkpoints_folder` / `results_folder` are created eagerly but
    nothing is saved to them yet in this revision.
    """

    def __init__(
        self,
        rectified_flow: RectifiedFlow,
        *,
        dataset: Dataset,
        num_train_steps = 70_000,
        learning_rate = 3e-4,
        batch_size = 16,
        adam_kwargs: dict = dict(),
        accelerate_kwargs: dict = dict(),
        checkpoints_folder: str = './checkpoints',
        results_folder: str = './results'
    ):
        super().__init__()

        self.accelerator = Accelerator(**accelerate_kwargs)

        optimizer = Adam(rectified_flow.parameters(), lr = learning_rate, **adam_kwargs)
        dataloader = DataLoader(dataset, batch_size = batch_size)

        # hand everything to accelerate so device placement / wrapping is handled for us
        self.model, self.optimizer, self.dl = self.accelerator.prepare(rectified_flow, optimizer, dataloader)

        self.num_train_steps = num_train_steps

        # output folders — created up front so later writes cannot fail on a missing dir
        self.checkpoints_folder = Path(checkpoints_folder)
        self.results_folder = Path(results_folder)

        self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
        self.results_folder.mkdir(exist_ok = True, parents = True)

        assert self.checkpoints_folder.is_dir()
        assert self.results_folder.is_dir()

    def forward(self):
        """Run the training loop for `self.num_train_steps` steps."""
        dl = cycle(self.dl)

        for _ in range(self.num_train_steps):
            self.model.train()

            batch = next(dl)
            loss = self.model(batch)

            self.accelerator.print(f'loss: {loss.item():.3f}')
            self.accelerator.backward(loss)

            self.optimizer.step()
            self.optimizer.zero_grad()

        print('training complete')
725768

726-
def __call__(self):
727-
return self

0 commit comments

Comments
 (0)