Skip to content

Commit 6c551ab

Browse files
committed
add sampling, show example in readme
1 parent 8fc6d8c commit 6c551ab

File tree

4 files changed

+54
-4
lines changed

4 files changed

+54
-4
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,30 @@ sampled = reflow.sample()
6161
assert sampled.shape[1:] == images.shape[1:]
6262
```
6363

64+
With a `Trainer` based on `accelerate`
65+
66+
```python
67+
import torch
68+
from rectified_flow_pytorch import RectifiedFlow, ImageDataset, Unet, Trainer
69+
70+
model = Unet(dim = 64)
71+
72+
rectified_flow = RectifiedFlow(model)
73+
74+
img_dataset = ImageDataset(
75+
folder = './jpg',
76+
image_size = 256
77+
)
78+
79+
trainer = Trainer(
80+
rectified_flow,
81+
dataset = img_dataset,
82+
num_train_steps = 70_000
83+
)
84+
85+
trainer()
86+
```
87+
6488
## Citations
6589

6690
```bibtex

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.9"
3+
version = "0.0.10"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
RectifiedFlow,
33
Reflow,
44
ImageDataset,
5+
Unet,
56
Trainer
67
)

rectified_flow_pytorch/rectified_flow.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torchdiffeq import odeint
1212

1313
import torchvision
14+
from torchvision.utils import save_image
1415
from torchvision.models import VGG16_Weights
1516

1617
from einops import einsum, reduce, rearrange, repeat
@@ -726,7 +727,9 @@ def __init__(
726727
adam_kwargs: dict = dict(),
727728
accelerate_kwargs: dict = dict(),
728729
checkpoints_folder: str = './checkpoints',
729-
results_folder: str = './results'
730+
results_folder: str = './results',
731+
save_results_every: int = 100,
732+
num_samples: int = 16
730733
):
731734
super().__init__()
732735
self.accelerator = Accelerator(**accelerate_kwargs)
@@ -745,24 +748,46 @@ def __init__(
745748
self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
746749
self.results_folder.mkdir(exist_ok = True, parents = True)
747750

751+
self.save_results_every = save_results_every
752+
753+
self.num_sample_rows = int(math.sqrt(num_samples))
754+
assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square'
755+
self.num_samples = num_samples
756+
748757
assert self.checkpoints_folder.is_dir()
749758
assert self.results_folder.is_dir()
750759

751760
def forward(self):
752761

753762
dl = cycle(self.dl)
754763

755-
for _ in range(self.num_train_steps):
764+
for ind in range(self.num_train_steps):
765+
step = ind + 1
766+
756767
self.model.train()
757768

758769
data = next(dl)
759770
loss = self.model(data)
760771

761-
self.accelerator.print(f'loss: {loss.item():.3f}')
772+
self.accelerator.print(f'[{step}] loss: {loss.item():.3f}')
762773
self.accelerator.backward(loss)
763774

764775
self.optimizer.step()
765776
self.optimizer.zero_grad()
766777

778+
if not divisible_by(step, self.save_results_every):
779+
continue
780+
781+
self.accelerator.wait_for_everyone()
782+
783+
if self.accelerator.is_main_process:
784+
self.model.eval()
785+
sampled = self.model.sample(batch_size = self.num_samples)
786+
sampled.clamp_(0., 1.)
787+
788+
save_image(sampled, str(self.results_folder / f'results.{step}.png'), nrow = self.num_sample_rows)
789+
790+
self.accelerator.wait_for_everyone()
791+
767792
print('training complete')
768793

0 commit comments

Comments
 (0)