@@ -154,7 +154,9 @@ def __init__(
154
154
data_normalize_fn = normalize_to_neg_one_to_one ,
155
155
data_unnormalize_fn = unnormalize_to_zero_to_one ,
156
156
clip_during_sampling = False ,
157
- clip_values : Tuple [float , float ] = (- 1.5 , 1.5 )
157
+ clip_values : Tuple [float , float ] = (- 1. , 1. ),
158
+ clip_flow_during_sampling = False , # clamping the predicted flow at sampling time seems to help a lot for models trained with the predict epsilon objective, at least in my experience
159
+ clip_flow_values : Tuple [float , float ] = (- 3. , 3 )
158
160
):
159
161
super ().__init__ ()
160
162
@@ -204,8 +206,13 @@ def __init__(
204
206
self .odeint_kwargs = odeint_kwargs
205
207
self .data_shape = data_shape
206
208
209
+ # clipping of the data and of the flow during sampling (flow clipping is aimed at epsilon prediction)
210
+
207
211
self .clip_during_sampling = clip_during_sampling
212
+ self .clip_flow_during_sampling = clip_flow_during_sampling
213
+
208
214
self .clip_values = clip_values
215
+ self .clip_flow_values = clip_flow_values
209
216
210
217
# consistency flow matching
211
218
@@ -268,7 +275,7 @@ def predict_flow(self, model: Module, noised, *, times):
268
275
noise = output
269
276
padded_times = append_dims (times , noised .ndim - 1 )
270
277
271
- flow = (noised - noise ) / padded_times . clamp ( min = 1e-20 )
278
+ flow = (noised - noise ) / padded_times
272
279
273
280
else :
274
281
raise ValueError (f'unknown objective { self .predict } ' )
@@ -301,12 +308,17 @@ def sample(
301
308
302
309
maybe_clip = (lambda t : t .clamp_ (* self .clip_values )) if self .clip_during_sampling else identity
303
310
311
+ maybe_clip_flow = (lambda t : t .clamp_ (* self .clip_flow_values )) if self .clip_flow_during_sampling else identity
312
+
304
313
# ode step function
305
314
306
315
def ode_fn (t , x ):
# ODE step function: given time t and current state x, return the flow at that point.
# NOTE(review): presumably handed to the ODE solver configured via self.odeint_kwargs - confirm against the rest of `sample`.
307
316
# clamp the state in place to self.clip_values when clip_during_sampling is set; `identity` is used otherwise
x = maybe_clip (x )
308
317
309
318
# predict_flow returns a pair; only its second element (the flow) is used here
_ , flow = self .predict_flow (model , x , times = t , ** model_kwargs )
319
+
320
# clamp the predicted flow in place to self.clip_flow_values when clip_flow_during_sampling is set; `identity` is used otherwise
+ flow = maybe_clip_flow (flow )
321
+
310
322
return flow
311
323
312
324
# start with random gaussian noise - y0
0 commit comments