@@ -152,7 +152,9 @@ def __init__(
152
152
consistency_delta_time = 1e-3 ,
153
153
consistency_loss_weight = 1. ,
154
154
data_normalize_fn = normalize_to_neg_one_to_one ,
155
- data_unnormalize_fn = unnormalize_to_zero_to_one
155
+ data_unnormalize_fn = unnormalize_to_zero_to_one ,
156
+ clip_during_sampling = True ,
157
+ clip_values : Tuple [float , float ] = (- 1. , 1. )
156
158
):
157
159
super ().__init__ ()
158
160
@@ -202,6 +204,9 @@ def __init__(
202
204
self .odeint_kwargs = odeint_kwargs
203
205
self .data_shape = data_shape
204
206
207
+ self .clip_during_sampling = clip_during_sampling
208
+ self .clip_values = clip_values
209
+
205
210
# consistency flow matching
206
211
207
212
self .use_consistency = use_consistency
@@ -291,7 +296,16 @@ def sample(
291
296
data_shape = default (data_shape , self .data_shape )
292
297
assert exists (data_shape ), 'you need to either pass in a `data_shape` or have trained at least with one forward'
293
298
299
+ # clipping still helps for predict noise objective
300
+ # much like original ddpm paper trick
301
+
302
+ maybe_clip = (lambda t : t .clamp_ (* self .clip_values )) if self .clip_during_sampling else identity
303
+
304
+ # ode step function
305
+
294
306
def ode_fn (t , x ):
307
+ x = maybe_clip (x )
308
+
295
309
_ , flow = self .predict_flow (model , x , times = t , ** model_kwargs )
296
310
return flow
297
311
0 commit comments