Skip to content

Commit 89584d7

Browse files
committed
allow for passing in own frozen model during reflow, for testing out a personal idea using EMA
1 parent 7c821a7 commit 89584d7

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.3"
3+
version = "0.0.4"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ class Reflow(Module):
155155
def __init__(
156156
self,
157157
rectified_flow: RectifiedFlow,
158+
frozen_model: RectifiedFlow | None = None,
158159
*,
159160
batch_size = 16,
160161

@@ -168,12 +169,15 @@ def __init__(
168169

169170
self.model = rectified_flow
170171

171-
# make a frozen copy of the model and set requires grad to be False for all parameters for safe measure
172+
if not exists(frozen_model):
173+
# make a frozen copy of the model and set requires grad to be False for all parameters for safe measure
172174

173-
self.frozen_model = deepcopy(rectified_flow)
175+
frozen_model = deepcopy(rectified_flow)
174176

175-
for p in self.frozen_model.parameters():
176-
p.detach_()
177+
for p in frozen_model.parameters():
178+
p.detach_()
179+
180+
self.frozen_model = frozen_model
177181

178182
def parameters(self):
179183
return self.model.parameters() # omit frozen model

0 commit comments

Comments
 (0)