Skip to content

Commit ed7f2cd

Browse files
committed
able to use one nested dictionary to instantiate everything
1 parent 553ef25 commit ed7f2cd

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rectified-flow-pytorch"
3-
version = "0.0.28"
3+
version = "0.0.29"
44
description = "Rectified Flow in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

rectified_flow_pytorch/rectified_flow.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def forward(self, pred, target, **kwargs):
125125
class RectifiedFlow(Module):
126126
def __init__(
127127
self,
128-
model: Module,
128+
model: dict | Module,
129129
time_cond_kwarg: str | None = 'times',
130130
odeint_kwargs: dict = dict(
131131
atol = 1e-5,
@@ -154,6 +154,10 @@ def __init__(
154154
data_unnormalize_fn = unnormalize_to_zero_to_one
155155
):
156156
super().__init__()
157+
158+
if isinstance(model, dict):
159+
model = Unet(**model)
160+
157161
self.model = model
158162
self.time_cond_kwarg = time_cond_kwarg # whether the model is to be conditioned on the times
159163

@@ -832,9 +836,9 @@ def cycle(dl):
832836
class Trainer(Module):
833837
def __init__(
834838
self,
835-
rectified_flow: RectifiedFlow,
839+
rectified_flow: dict | RectifiedFlow,
836840
*,
837-
dataset: Dataset,
841+
dataset: dict | Dataset,
838842
num_train_steps = 70_000,
839843
learning_rate = 3e-4,
840844
batch_size = 16,
@@ -851,6 +855,12 @@ def __init__(
851855
super().__init__()
852856
self.accelerator = Accelerator(**accelerate_kwargs)
853857

858+
if isinstance(dataset, dict):
859+
dataset = ImageDataset(**dataset)
860+
861+
if isinstance(rectified_flow, dict):
862+
rectified_flow = RectifiedFlow(**rectified_flow)
863+
854864
self.model = rectified_flow
855865

856866
# determine whether to keep track of EMA (if not using consistency FM)

0 commit comments

Comments
 (0)