Skip to content

Commit cc6efa4

Browse files
committed
add ability to clip the flow during sampling, seems to help much better for predict noise objective
1 parent 74d1c3c commit cc6efa4

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.1.6"
3+
version = "0.1.7"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ def __init__(
154154
data_normalize_fn = normalize_to_neg_one_to_one,
155155
data_unnormalize_fn = unnormalize_to_zero_to_one,
156156
clip_during_sampling = False,
157-
clip_values: Tuple[float, float] = (-1.5, 1.5)
157+
clip_values: Tuple[float, float] = (-1., 1.),
158+
clip_flow_during_sampling = False, # this seems to help a lot when training with predict epsilon, at least for me
159+
clip_flow_values: Tuple[float, float] = (-3., 3)
158160
):
159161
super().__init__()
160162

@@ -204,8 +206,13 @@ def __init__(
204206
self.odeint_kwargs = odeint_kwargs
205207
self.data_shape = data_shape
206208

209+
# clipping for epsilon prediction
210+
207211
self.clip_during_sampling = clip_during_sampling
212+
self.clip_flow_during_sampling = clip_flow_during_sampling
213+
208214
self.clip_values = clip_values
215+
self.clip_flow_values = clip_flow_values
209216

210217
# consistency flow matching
211218

@@ -268,7 +275,7 @@ def predict_flow(self, model: Module, noised, *, times):
268275
noise = output
269276
padded_times = append_dims(times, noised.ndim - 1)
270277

271-
flow = (noised - noise) / padded_times.clamp(min = 1e-20)
278+
flow = (noised - noise) / padded_times
272279

273280
else:
274281
raise ValueError(f'unknown objective {self.predict}')
@@ -301,12 +308,17 @@ def sample(
301308

302309
maybe_clip = (lambda t: t.clamp_(*self.clip_values)) if self.clip_during_sampling else identity
303310

311+
maybe_clip_flow = (lambda t: t.clamp_(*self.clip_flow_values)) if self.clip_flow_during_sampling else identity
312+
304313
# ode step function
305314

306315
def ode_fn(t, x):
307316
x = maybe_clip(x)
308317

309318
_, flow = self.predict_flow(model, x, times = t, **model_kwargs)
319+
320+
flow = maybe_clip_flow(flow)
321+
310322
return flow
311323

312324
# start with random gaussian noise - y0

0 commit comments

Comments
 (0)