File tree 2 files changed +9
-5
lines changed
2 files changed +9
-5
lines changed Original file line number Diff line number Diff line change 1
1
[project ]
2
2
name = " rectified-flow-pytorch"
3
- version = " 0.0.3 "
3
+ version = " 0.0.4 "
4
4
description = " Rectified Flow in Pytorch"
5
5
authors = [
6
6
{
name =
" Phil Wang" ,
email =
" [email protected] " }
Original file line number Diff line number Diff line change @@ -155,6 +155,7 @@ class Reflow(Module):
155
155
def __init__ (
156
156
self ,
157
157
rectified_flow : RectifiedFlow ,
158
+ frozen_model : RectifiedFlow | None = None ,
158
159
* ,
159
160
batch_size = 16 ,
160
161
@@ -168,12 +169,15 @@ def __init__(
168
169
169
170
self .model = rectified_flow
170
171
171
- # make a frozen copy of the model and set requires grad to be False for all parameters for safe measure
172
+ if not exists (frozen_model ):
173
+ # make a frozen copy of the model and set requires grad to be False for all parameters for safe measure
172
174
173
- self . frozen_model = deepcopy (rectified_flow )
175
+ frozen_model = deepcopy (rectified_flow )
174
176
175
- for p in self .frozen_model .parameters ():
176
- p .detach_ ()
177
+ for p in frozen_model .parameters ():
178
+ p .detach_ ()
179
+
180
+ self .frozen_model = frozen_model
177
181
178
182
def parameters (self ):
179
183
return self .model .parameters () # omit frozen model
You can’t perform that action at this time.
0 commit comments