@@ -154,7 +154,9 @@ def __init__(
         data_normalize_fn = normalize_to_neg_one_to_one,
         data_unnormalize_fn = unnormalize_to_zero_to_one,
         clip_during_sampling = False,
-        clip_values: Tuple[float, float] = (-1.5, 1.5)
+        clip_values: Tuple[float, float] = (-1., 1.),
+        clip_flow_during_sampling = False,  # this seems to help a lot when training with predict epsilon, at least for me
+        clip_flow_values: Tuple[float, float] = (-3., 3.)
     ):
         super().__init__()
@@ -204,8 +206,13 @@ def __init__(
         self.odeint_kwargs = odeint_kwargs
         self.data_shape = data_shape

+        # clipping for epsilon prediction
+
         self.clip_during_sampling = clip_during_sampling
+        self.clip_flow_during_sampling = clip_flow_during_sampling
+
         self.clip_values = clip_values
+        self.clip_flow_values = clip_flow_values

         # consistency flow matching

@@ -268,7 +275,7 @@ def predict_flow(self, model: Module, noised, *, times):
             noise = output
             padded_times = append_dims(times, noised.ndim - 1)

-            flow = (noised - noise) / padded_times.clamp(min = 1e-20)
+            flow = (noised - noise) / padded_times

         else:
             raise ValueError(f'unknown objective {self.predict}')
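For context, a short derivation of the epsilon-prediction branch above, assuming the linear flow-matching interpolation $x_t = t\,x_1 + (1 - t)\,\epsilon$ with target flow $x_1 - \epsilon$ (an assumption about this file's parametrization; only the final expression appears in the diff). Solving for $x_1$ and substituting gives

$$\text{flow} = \frac{x_t - (1 - t)\,\epsilon}{t} - \epsilon = \frac{x_t - \epsilon}{t},$$

which is the `(noised - noise) / padded_times` line. Near $t = 0$ the division amplifies any error in the predicted $\epsilon$, which is likely why clamping the resulting flow during sampling (the `clip_flow_values` added in this commit) helps with the epsilon objective.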
@@ -301,12 +308,17 @@ def sample(

         maybe_clip = (lambda t: t.clamp_(*self.clip_values)) if self.clip_during_sampling else identity

+        maybe_clip_flow = (lambda t: t.clamp_(*self.clip_flow_values)) if self.clip_flow_during_sampling else identity
+
         # ode step function

         def ode_fn(t, x):
             x = maybe_clip(x)

             _, flow = self.predict_flow(model, x, times = t, **model_kwargs)
+
+            flow = maybe_clip_flow(flow)
+
             return flow

         # start with random gaussian noise - y0
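For illustration, a minimal sketch of turning the new behavior on, assuming the arguments above belong to the repository's flow wrapper class. The class name `RectifiedFlow`, the `model` variable, the `predict` keyword, and the bare `sample()` call are placeholders inferred from the surrounding code, not confirmed by this diff; only the keyword arguments introduced in this commit come from it.

# hypothetical usage sketch; names other than the new keyword arguments are placeholders
rectified_flow = RectifiedFlow(
    model,
    predict = 'noise',                   # epsilon-prediction objective (see predict_flow above)
    clip_flow_during_sampling = True,    # clamp the flow returned inside ode_fn
    clip_flow_values = (-3., 3.)         # range used by maybe_clip_flow
)

sampled = rectified_flow.sample()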