Skip to content

Commit d6d352e

Browse files
committed
for predict noise objective, clipping still helps a lot
1 parent 1c47738 commit d6d352e

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.1.2"
3+
version = "0.1.4"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ def __init__(
152152
consistency_delta_time = 1e-3,
153153
consistency_loss_weight = 1.,
154154
data_normalize_fn = normalize_to_neg_one_to_one,
155-
data_unnormalize_fn = unnormalize_to_zero_to_one
155+
data_unnormalize_fn = unnormalize_to_zero_to_one,
156+
clip_during_sampling = True,
157+
clip_values: Tuple[float, float] = (-1., 1.)
156158
):
157159
super().__init__()
158160

@@ -202,6 +204,9 @@ def __init__(
202204
self.odeint_kwargs = odeint_kwargs
203205
self.data_shape = data_shape
204206

207+
self.clip_during_sampling = clip_during_sampling
208+
self.clip_values = clip_values
209+
205210
# consistency flow matching
206211

207212
self.use_consistency = use_consistency
@@ -291,7 +296,16 @@ def sample(
291296
data_shape = default(data_shape, self.data_shape)
292297
assert exists(data_shape), 'you need to either pass in a `data_shape` or have trained at least with one forward'
293298

299+
# clipping still helps for predict noise objective
300+
# much like original ddpm paper trick
301+
302+
maybe_clip = (lambda t: t.clamp_(*self.clip_values)) if self.clip_during_sampling else identity
303+
304+
# ode step function
305+
294306
def ode_fn(t, x):
307+
x = maybe_clip(x)
308+
295309
_, flow = self.predict_flow(model, x, times = t, **model_kwargs)
296310
return flow
297311

0 commit comments

Comments
 (0)