Skip to content

Commit 92f73ff

Browse files
committed
able to ues the consistency EMA as the main EMA being sampled from the Trainer. remove wip as things are working
1 parent 8cc8bd3 commit 92f73ff

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<img src="./rf.png" width="400px"></img>
22

3-
## Rectified Flow - Pytorch (wip)
3+
## Rectified Flow - Pytorch
44

55
Implementation of <a href="https://www.cs.utexas.edu/~lqiang/rectflow/html/intro.html">rectified flow</a> and some of its followup research / improvements in Pytorch
66

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.22"
3+
version = "0.0.23"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -824,19 +824,25 @@ def __init__(
824824
num_samples: int = 16,
825825
adam_kwargs: dict = dict(),
826826
accelerate_kwargs: dict = dict(),
827-
ema_kwargs: dict = dict()
827+
ema_kwargs: dict = dict(),
828+
use_consistency_ema = False # whether to just use the EMA from the velocity consistency from the consistency FM paper
828829
):
829830
super().__init__()
830831
self.accelerator = Accelerator(**accelerate_kwargs)
831832

832833
self.model = rectified_flow
833834

834835
if self.is_main:
835-
self.ema_model = EMA(
836-
self.model,
837-
forward_method_names = ('sample',),
838-
**ema_kwargs
839-
)
836+
if use_consistency_ema:
837+
assert self.model.use_consistency, 'model must be using the consistency EMA for it to be reused as the main EMA model during sampling'
838+
839+
self.ema_model = self.model.ema_model
840+
else:
841+
self.ema_model = EMA(
842+
self.model,
843+
forward_method_names = ('sample',),
844+
**ema_kwargs
845+
)
840846

841847
self.ema_model.to(self.accelerator.device)
842848

0 commit comments

Comments
 (0)