Skip to content

Commit

Permalink
final tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 20, 2024
1 parent cc6efa4 commit 5c25a2d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rectified-flow-pytorch"
version = "0.1.7"
version = "0.1.8"
description = "Rectified Flow in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down
6 changes: 5 additions & 1 deletion rectified_flow_pytorch/rectified_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(
data_unnormalize_fn = unnormalize_to_zero_to_one,
clip_during_sampling = False,
clip_values: Tuple[float, float] = (-1., 1.),
clip_flow_during_sampling = False, # this seems to help a lot when training with predict epsilon, at least for me
clip_flow_during_sampling = None, # this seems to help a lot when training with predict epsilon, at least for me
clip_flow_values: Tuple[float, float] = (-3., 3)
):
super().__init__()
Expand All @@ -170,6 +170,10 @@ def __init__(

self.predict = predict

# automatically default to a working setting for predict epsilon

clip_flow_during_sampling = default(clip_flow_during_sampling, predict == 'noise')

# loss fn

if loss_fn == 'mse':
Expand Down

0 comments on commit 5c25a2d

Please sign in to comment.