From 5c25a2dba0464732a255ec8017a719bdcce3bffc Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 20 Aug 2024 09:56:26 -0700 Subject: [PATCH] final tweak --- pyproject.toml | 2 +- rectified_flow_pytorch/rectified_flow.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0bca6a5..5a48fc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rectified-flow-pytorch" -version = "0.1.7" +version = "0.1.8" description = "Rectified Flow in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } diff --git a/rectified_flow_pytorch/rectified_flow.py b/rectified_flow_pytorch/rectified_flow.py index 88f7f16..b4ea7d0 100644 --- a/rectified_flow_pytorch/rectified_flow.py +++ b/rectified_flow_pytorch/rectified_flow.py @@ -155,7 +155,7 @@ def __init__( data_unnormalize_fn = unnormalize_to_zero_to_one, clip_during_sampling = False, clip_values: Tuple[float, float] = (-1., 1.), - clip_flow_during_sampling = False, # this seems to help a lot when training with predict epsilon, at least for me + clip_flow_during_sampling = None, # this seems to help a lot when training with predict epsilon, at least for me clip_flow_values: Tuple[float, float] = (-3., 3) ): super().__init__() @@ -170,6 +170,10 @@ def __init__( self.predict = predict + # automatically default to a working setting for predict epsilon + + clip_flow_during_sampling = default(clip_flow_during_sampling, predict == 'noise') + # loss fn if loss_fn == 'mse':