@@ -125,7 +125,7 @@ def forward(self, pred, target, **kwargs):
125
125
class RectifiedFlow (Module ):
126
126
def __init__ (
127
127
self ,
128
- model : Module ,
128
+ model : dict | Module ,
129
129
time_cond_kwarg : str | None = 'times' ,
130
130
odeint_kwargs : dict = dict (
131
131
atol = 1e-5 ,
@@ -154,6 +154,10 @@ def __init__(
154
154
data_unnormalize_fn = unnormalize_to_zero_to_one
155
155
):
156
156
super ().__init__ ()
157
+
158
+ if isinstance (model , dict ):
159
+ model = Unet (** model )
160
+
157
161
self .model = model
158
162
self .time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times
159
163
@@ -832,9 +836,9 @@ def cycle(dl):
832
836
class Trainer (Module ):
833
837
def __init__ (
834
838
self ,
835
- rectified_flow : RectifiedFlow ,
839
+ rectified_flow : dict | RectifiedFlow ,
836
840
* ,
837
- dataset : Dataset ,
841
+ dataset : dict | Dataset ,
838
842
num_train_steps = 70_000 ,
839
843
learning_rate = 3e-4 ,
840
844
batch_size = 16 ,
@@ -851,6 +855,12 @@ def __init__(
851
855
super ().__init__ ()
852
856
self .accelerator = Accelerator (** accelerate_kwargs )
853
857
858
+ if isinstance (dataset , dict ):
859
+ dataset = ImageDataset (** dataset )
860
+
861
+ if isinstance (rectified_flow , dict ):
862
+ rectified_flow = RectifiedFlow (** rectified_flow )
863
+
854
864
self .model = rectified_flow
855
865
856
866
# determine whether to keep track of EMA (if not using consistency FM)
0 commit comments