@@ -30,6 +30,14 @@ def append_dims(t, ndims):
30
30
shape = t .shape
31
31
return t .reshape (* shape , * ((1 ,) * ndims ))
32
32
33
+ # normalizing helpers
34
+
35
+ def normalize_to_neg_one_to_one (img ):
36
+ return img * 2 - 1
37
+
38
+ def unnormalize_to_zero_to_one (t ):
39
+ return (t + 1 ) * 0.5
40
+
33
41
# losses
34
42
35
43
class LPIPSLoss (Module ):
@@ -100,7 +108,9 @@ def __init__(
100
108
loss_fn : Literal ['mse' , 'pseudo_huber' ] | Module = 'mse' ,
101
109
loss_fn_kwargs : dict = dict (),
102
110
data_shape : Tuple [int , ...] | None = None ,
103
- immiscible = False
111
+ immiscible = False ,
112
+ data_normalize_fn = normalize_to_neg_one_to_one ,
113
+ data_unnormalize_fn = unnormalize_to_zero_to_one
104
114
):
105
115
super ().__init__ ()
106
116
self .model = model
@@ -135,6 +145,11 @@ def __init__(
135
145
136
146
self .immiscible = immiscible
137
147
148
+ # normalizing fn
149
+
150
+ self .data_normalize_fn = data_normalize_fn
151
+ self .data_unnormalize_fn = data_unnormalize_fn
152
+
138
153
@property
139
154
def device (self ):
140
155
return next (self .model .parameters ()).device
@@ -177,7 +192,8 @@ def ode_fn(t, x):
177
192
sampled_data = trajectory [- 1 ]
178
193
179
194
self .train (was_training )
180
- return sampled_data
195
+
196
+ return self .data_unnormalize_fn (sampled_data )
181
197
182
198
def forward (
183
199
self ,
@@ -187,6 +203,8 @@ def forward(
187
203
):
188
204
batch , * data_shape = data .shape
189
205
206
+ data = self .data_normalize_fn (data )
207
+
190
208
self .data_shape = default (self .data_shape , data_shape )
191
209
192
210
# x0 - gaussian noise, x1 - data
0 commit comments