Skip to content

Commit 2919f23

Browse files
committed
default functions for normalizing image data from 0 to 1 to -1 to 1 and back when sampling
1 parent 4bb628b commit 2919f23

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

rectified_flow_pytorch/rectified_flow.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ def append_dims(t, ndims):
3030
shape = t.shape
3131
return t.reshape(*shape, *((1,) * ndims))
3232

33+
# normalizing helpers
34+
35+
def normalize_to_neg_one_to_one(img):
36+
return img * 2 - 1
37+
38+
def unnormalize_to_zero_to_one(t):
39+
return (t + 1) * 0.5
40+
3341
# losses
3442

3543
class LPIPSLoss(Module):
@@ -100,7 +108,9 @@ def __init__(
100108
loss_fn: Literal['mse', 'pseudo_huber'] | Module = 'mse',
101109
loss_fn_kwargs: dict = dict(),
102110
data_shape: Tuple[int, ...] | None = None,
103-
immiscible = False
111+
immiscible = False,
112+
data_normalize_fn = normalize_to_neg_one_to_one,
113+
data_unnormalize_fn = unnormalize_to_zero_to_one
104114
):
105115
super().__init__()
106116
self.model = model
@@ -135,6 +145,11 @@ def __init__(
135145

136146
self.immiscible = immiscible
137147

148+
# normalizing fn
149+
150+
self.data_normalize_fn = data_normalize_fn
151+
self.data_unnormalize_fn = data_unnormalize_fn
152+
138153
@property
139154
def device(self):
140155
return next(self.model.parameters()).device
@@ -177,7 +192,8 @@ def ode_fn(t, x):
177192
sampled_data = trajectory[-1]
178193

179194
self.train(was_training)
180-
return sampled_data
195+
196+
return self.data_unnormalize_fn(sampled_data)
181197

182198
def forward(
183199
self,
@@ -187,6 +203,8 @@ def forward(
187203
):
188204
batch, *data_shape = data.shape
189205

206+
data = self.data_normalize_fn(data)
207+
190208
self.data_shape = default(self.data_shape, data_shape)
191209

192210
# x0 - gaussian noise, x1 - data

0 commit comments

Comments
 (0)