Skip to content

Commit 1c47738

Browse files
committed
move the reflow into another file, as it may be obsoleted by the consistency fm paper
1 parent 73a0ca7 commit 1c47738

File tree

4 files changed

+198
-176
lines changed

4 files changed

+198
-176
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from rectified_flow_pytorch.rectified_flow import (
22
RectifiedFlow,
3-
Reflow,
43
ImageDataset,
54
Unet,
65
Trainer,
7-
ReflowTrainer
86
)
7+
8+
from rectified_flow_pytorch.reflow import (
9+
Reflow,
10+
ReflowTrainer
11+
)

rectified_flow_pytorch/rectified_flow.py

Lines changed: 0 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -415,56 +415,6 @@ def get_noised_and_flows(model, t):
415415

416416
return total_loss, LossBreakdown(total_loss, main_loss, data_match_loss, velocity_match_loss)
417417

418-
# reflow wrapper
419-
420-
class Reflow(Module):
421-
def __init__(
422-
self,
423-
rectified_flow: RectifiedFlow,
424-
frozen_model: RectifiedFlow | None = None,
425-
*,
426-
batch_size = 16,
427-
428-
):
429-
super().__init__()
430-
model, data_shape = rectified_flow.model, rectified_flow.data_shape
431-
assert exists(data_shape), '`data_shape` must be defined in RectifiedFlow'
432-
433-
self.batch_size = batch_size
434-
self.data_shape = data_shape
435-
436-
self.model = rectified_flow
437-
438-
if not exists(frozen_model):
439-
# make a frozen copy of the model and set requires grad to be False for all parameters for safe measure
440-
441-
frozen_model = deepcopy(rectified_flow)
442-
443-
for p in frozen_model.parameters():
444-
p.detach_()
445-
446-
self.frozen_model = frozen_model
447-
448-
def device(self):
449-
return next(self.parameters()).device
450-
451-
def parameters(self):
452-
return self.model.parameters() # omit frozen model
453-
454-
def sample(self, *args, **kwargs):
455-
return self.model.sample(*args, **kwargs)
456-
457-
def forward(self):
458-
459-
noise = torch.randn((self.batch_size, *self.data_shape), device = self.device)
460-
sampled_output = self.frozen_model.sample(noise = noise)
461-
462-
# the coupling in the paper is (noise, sampled_output)
463-
464-
loss = self.model(sampled_output, noise = noise)
465-
466-
return loss
467-
468418
# unet
469419

470420
from functools import partial
@@ -1022,126 +972,3 @@ def forward(self):
1022972
self.accelerator.wait_for_everyone()
1023973

1024974
print('training complete')
1025-
1026-
# reflow trainer
1027-
1028-
class ReflowTrainer(Module):
1029-
def __init__(
1030-
self,
1031-
rectified_flow: RectifiedFlow,
1032-
*,
1033-
num_train_steps = 70_000,
1034-
learning_rate = 3e-4,
1035-
batch_size = 16,
1036-
checkpoints_folder: str = './checkpoints',
1037-
results_folder: str = './results',
1038-
save_results_every: int = 100,
1039-
checkpoint_every: int = 1000,
1040-
num_samples: int = 16,
1041-
adam_kwargs: dict = dict(),
1042-
accelerate_kwargs: dict = dict(),
1043-
ema_kwargs: dict = dict()
1044-
):
1045-
super().__init__()
1046-
self.accelerator = Accelerator(**accelerate_kwargs)
1047-
1048-
assert not rectified_flow.use_consistency, 'reflow is not needed if using consistency flow matching'
1049-
1050-
self.model = Reflow(rectified_flow)
1051-
1052-
if self.is_main:
1053-
self.ema_model = EMA(
1054-
self.model,
1055-
forward_method_names = ('sample',),
1056-
**ema_kwargs
1057-
)
1058-
1059-
self.ema_model.to(self.accelerator.device)
1060-
1061-
self.optimizer = Adam(rectified_flow.parameters(), lr = learning_rate, **adam_kwargs)
1062-
1063-
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
1064-
1065-
self.batch_size = batch_size
1066-
self.num_train_steps = num_train_steps
1067-
1068-
self.checkpoints_folder = Path(checkpoints_folder)
1069-
self.results_folder = Path(results_folder)
1070-
1071-
self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
1072-
self.results_folder.mkdir(exist_ok = True, parents = True)
1073-
1074-
self.checkpoint_every = checkpoint_every
1075-
self.save_results_every = save_results_every
1076-
1077-
self.num_sample_rows = int(math.sqrt(num_samples))
1078-
assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square'
1079-
self.num_samples = num_samples
1080-
1081-
assert self.checkpoints_folder.is_dir()
1082-
assert self.results_folder.is_dir()
1083-
1084-
@property
1085-
def is_main(self):
1086-
return self.accelerator.is_main_process
1087-
1088-
def log(self, *args, **kwargs):
1089-
return self.accelerator.log(*args, **kwargs)
1090-
1091-
def log_images(self, *args, **kwargs):
1092-
return self.accelerator.log(*args, **kwargs)
1093-
1094-
def save(self, path):
1095-
if not self.is_main:
1096-
return
1097-
1098-
save_package = dict(
1099-
model = self.accelerator.unwrap_model(self.model).state_dict(),
1100-
ema_model = self.ema_model.state_dict(),
1101-
optimizer = self.accelerator.unwrap_model(self.optimizer).state_dict(),
1102-
)
1103-
1104-
torch.save(save_package, str(self.checkpoints_folder / path))
1105-
1106-
def forward(self):
1107-
1108-
for ind in range(self.num_train_steps):
1109-
step = ind + 1
1110-
1111-
self.model.train()
1112-
1113-
loss = self.model(batch_size = self.batch_size)
1114-
1115-
self.log(loss, step = step)
1116-
1117-
self.accelerator.print(f'[{step}] reflow loss: {loss.item():.3f}')
1118-
self.accelerator.backward(loss)
1119-
1120-
self.optimizer.step()
1121-
self.optimizer.zero_grad()
1122-
1123-
if self.is_main:
1124-
self.ema_model.update()
1125-
1126-
self.accelerator.wait_for_everyone()
1127-
1128-
if self.is_main:
1129-
if divisible_by(step, self.save_results_every):
1130-
self.ema_model.ema_model.data_shape = self.model.data_shape
1131-
1132-
with torch.no_grad():
1133-
sampled = self.ema_model.sample(batch_size = self.num_samples)
1134-
1135-
sampled = rearrange(sampled, '(row col) c h w -> c (row h) (col w)', row = self.num_sample_rows)
1136-
sampled.clamp_(0., 1.)
1137-
1138-
self.log_images(sampled, step = step)
1139-
1140-
save_image(sampled, str(self.results_folder / f'results.{step}.png'))
1141-
1142-
if divisible_by(step, self.checkpoint_every):
1143-
self.save(f'checkpoint.{step}.pt')
1144-
1145-
self.accelerator.wait_for_everyone()
1146-
1147-
print('reflow training complete')

rectified_flow_pytorch/reflow.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
from copy import deepcopy
2+
from pathlib import Path
3+
4+
import torch
5+
from torch.optim import Adam
6+
from torch.nn import Module, ModuleList
7+
8+
from rectified_flow_pytorch.rectified_flow import RectifiedFlow
9+
10+
from ema_pytorch import EMA
11+
from accelerate import Accelerator
12+
13+
# helpers
14+
15+
def exists(v):
    """Return True when *v* is not None."""
    return v is not None

def default(v, d):
    """Return *v* when it is not None, otherwise the fallback *d*."""
    if exists(v):
        return v
    return d
20+
21+
# reflow wrapper
22+
23+
class Reflow(Module):
    """Reflow wrapper (Rectified Flow, Liu et al. 2022).

    Trains ``rectified_flow`` on couplings ``(noise, sampled_output)``
    generated by a frozen copy of the model, straightening the learned flow
    so fewer sampling steps are needed.

    Args:
        rectified_flow: the model to be reflowed; must have `data_shape` set.
        frozen_model: optional teacher. Defaults to a detached deep copy of
            ``rectified_flow``.
        batch_size: default number of noise samples drawn per forward pass.
    """

    def __init__(
        self,
        rectified_flow: RectifiedFlow,
        frozen_model: RectifiedFlow | None = None,
        *,
        batch_size = 16,
    ):
        super().__init__()
        data_shape = rectified_flow.data_shape
        assert exists(data_shape), '`data_shape` must be defined in RectifiedFlow'

        self.batch_size = batch_size
        self.data_shape = data_shape

        self.model = rectified_flow

        if not exists(frozen_model):
            # make a frozen copy of the model and set requires grad to be False
            # for all parameters for safe measure

            frozen_model = deepcopy(rectified_flow)

            for p in frozen_model.parameters():
                p.detach_()

        self.frozen_model = frozen_model

    @property
    def device(self):
        # fix: was a plain method — `self.device` in `forward` would have
        # evaluated to a bound method, not a torch.device, crashing torch.randn
        return next(self.parameters()).device

    def parameters(self):
        return self.model.parameters() # omit frozen model

    def sample(self, *args, **kwargs):
        """Sample from the (trainable) student model."""
        return self.model.sample(*args, **kwargs)

    def forward(self, batch_size = None):
        """One reflow training step; returns the scalar flow-matching loss.

        fix: accepts an optional `batch_size` — ReflowTrainer calls
        `self.model(batch_size = ...)`, which the old zero-argument
        forward rejected with a TypeError.
        """
        batch_size = default(batch_size, self.batch_size)

        noise = torch.randn((batch_size, *self.data_shape), device = self.device)
        sampled_output = self.frozen_model.sample(noise = noise)

        # the coupling in the paper is (noise, sampled_output)

        loss = self.model(sampled_output, noise = noise)

        return loss
70+
71+
# reflow trainer
72+
73+
class ReflowTrainer(Module):
    """Accelerate-based training loop for one round of :class:`Reflow`.

    Wraps ``rectified_flow`` in a ``Reflow`` module, trains it for
    ``num_train_steps`` steps, periodically saving EMA sample grids and
    checkpoints. Interface mirrors the main ``Trainer`` in
    rectified_flow.py, minus the dataset (reflow generates its own data).
    """

    def __init__(
        self,
        rectified_flow: RectifiedFlow,
        *,
        num_train_steps = 70_000,
        learning_rate = 3e-4,
        batch_size = 16,
        checkpoints_folder: str = './checkpoints',
        results_folder: str = './results',
        save_results_every: int = 100,
        checkpoint_every: int = 1000,
        num_samples: int = 16,
        adam_kwargs: dict = dict(),
        accelerate_kwargs: dict = dict(),
        ema_kwargs: dict = dict()
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)

        assert not rectified_flow.use_consistency, 'reflow is not needed if using consistency flow matching'

        self.model = Reflow(rectified_flow)

        # EMA copy is kept on the main process only
        if self.is_main:
            self.ema_model = EMA(
                self.model,
                forward_method_names = ('sample',),
                **ema_kwargs
            )

            self.ema_model.to(self.accelerator.device)

        self.optimizer = Adam(rectified_flow.parameters(), lr = learning_rate, **adam_kwargs)

        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

        self.batch_size = batch_size
        self.num_train_steps = num_train_steps

        self.checkpoints_folder = Path(checkpoints_folder)
        self.results_folder = Path(results_folder)

        self.checkpoints_folder.mkdir(exist_ok = True, parents = True)
        self.results_folder.mkdir(exist_ok = True, parents = True)

        self.checkpoint_every = checkpoint_every
        self.save_results_every = save_results_every

        # fix: this module never imports `math` (the original called
        # math.sqrt and would raise NameError) — integer sqrt via ** 0.5
        self.num_sample_rows = int(num_samples ** 0.5)
        assert (self.num_sample_rows ** 2) == num_samples, f'{num_samples} must be a square'
        self.num_samples = num_samples

        assert self.checkpoints_folder.is_dir()
        assert self.results_folder.is_dir()

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def log(self, *args, **kwargs):
        return self.accelerator.log(*args, **kwargs)

    def log_images(self, *args, **kwargs):
        return self.accelerator.log(*args, **kwargs)

    def save(self, path):
        """Save model / EMA / optimizer state to `checkpoints_folder / path` (main process only)."""
        if not self.is_main:
            return

        save_package = dict(
            model = self.accelerator.unwrap_model(self.model).state_dict(),
            ema_model = self.ema_model.state_dict(),
            optimizer = self.accelerator.unwrap_model(self.optimizer).state_dict(),
        )

        torch.save(save_package, str(self.checkpoints_folder / path))

    def forward(self):
        """Run the full reflow training loop."""

        # fix: `rearrange` (einops) and `save_image` (torchvision) were
        # referenced but never imported in this module — import locally so
        # the module stays importable without them at top level
        from einops import rearrange
        from torchvision.utils import save_image

        for ind in range(self.num_train_steps):
            step = ind + 1

            self.model.train()

            loss = self.model(batch_size = self.batch_size)

            self.log(loss, step = step)

            self.accelerator.print(f'[{step}] reflow loss: {loss.item():.3f}')
            self.accelerator.backward(loss)

            self.optimizer.step()
            self.optimizer.zero_grad()

            if self.is_main:
                self.ema_model.update()

            self.accelerator.wait_for_everyone()

            if self.is_main:
                # fix: `divisible_by` lives in rectified_flow.py and is not
                # imported here — inline the modulo checks
                if (step % self.save_results_every) == 0:
                    # EMA copy may not have data_shape populated yet
                    self.ema_model.ema_model.data_shape = self.model.data_shape

                    with torch.no_grad():
                        sampled = self.ema_model.sample(batch_size = self.num_samples)

                    sampled = rearrange(sampled, '(row col) c h w -> c (row h) (col w)', row = self.num_sample_rows)
                    sampled.clamp_(0., 1.)

                    self.log_images(sampled, step = step)

                    save_image(sampled, str(self.results_folder / f'results.{step}.png'))

                if (step % self.checkpoint_every) == 0:
                    self.save(f'checkpoint.{step}.pt')

            self.accelerator.wait_for_everyone()

        print('reflow training complete')

0 commit comments

Comments
 (0)