
Commit 8004633

make sure entire readme runs without errors
1 parent 36fb46a

4 files changed: +14 -12 lines

README.md

Lines changed: 6 additions & 8 deletions
@@ -396,7 +396,7 @@ decoder = Decoder(
 ).cuda()
 
 for unet_number in (1, 2):
-    loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+    loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
     loss.backward()
 
 # do above for many steps
@@ -861,25 +861,23 @@ unet1 = Unet(
     text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
-    dim_mults=(1, 2, 4, 8)
+    dim_mults=(1, 2, 4, 8),
+    cond_on_text_encodings = True,
 ).cuda()
 
 unet2 = Unet(
     dim = 16,
     image_embed_dim = 512,
-    text_embed_dim = 512,
     cond_dim = 128,
     channels = 3,
     dim_mults = (1, 2, 4, 8, 16),
-    cond_on_text_encodings = True
 ).cuda()
 
 decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 1000,
-    condition_on_text_encodings = True
+    timesteps = 1000
 ).cuda()
 
 decoder_trainer = DecoderTrainer(
@@ -904,8 +902,8 @@ for unet_number in (1, 2):
 # after much training
 # you can sample from the exponentially moving averaged unets as so
 
-mock_image_embed = torch.randn(4, 512).cuda()
-images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
+mock_image_embed = torch.randn(32, 512).cuda()
+images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)
 ```
 
 ### Diffusion Prior Training
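
For reference, a consolidated sketch of how the corrected README snippets fit together after this commit: the decoder loss is computed with `text = text`, the first unet carries `cond_on_text_encodings = True`, and sampling passes `image_embed` by keyword with a batch that matches the mock text. The CLIP settings, mock tensors, and the bare `DecoderTrainer(decoder)` construction are assumptions drawn from the surrounding README rather than this diff and may differ between versions; a CUDA device is assumed, as in the README.

```python
import torch
from dalle2_pytorch import CLIP, Unet, Decoder, DecoderTrainer

# CLIP settings follow the project README; they are not part of this diff
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# first unet now conditions on text encodings (per the second hunk above)
unet1 = Unet(
    dim = 128,                       # illustrative value, not shown in the diff
    image_embed_dim = 512,
    text_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
    cond_on_text_encodings = True,
).cuda()

# second unet no longer takes text_embed_dim / cond_on_text_encodings
unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
).cuda()

# condition_on_text_encodings is no longer passed to Decoder
decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000
).cuda()

# mock training data shaped as in the README (batch 32, 256 x 256 images, 256 text tokens)
images = torch.randn(32, 3, 256, 256).cuda()
text = torch.randint(0, 49408, (32, 256)).cuda()

# training step from the first hunk: text is now passed alongside the images
for unet_number in (1, 2):
    loss = decoder(images, text = text, unet_number = unet_number)
    loss.backward()

# after much training, sample from the exponentially moving averaged unets via the trainer;
# image_embed is passed by keyword and the mock embedding batch matches the text batch
decoder_trainer = DecoderTrainer(decoder)
mock_image_embed = torch.randn(32, 512).cuda()
sampled = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (32, 3, 256, 256)
```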

dalle2_pytorch/dalle2_pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -1831,7 +1831,7 @@ def cast_model_parameters(
             channels == self.channels and \
             cond_on_image_embeds == self.cond_on_image_embeds and \
             cond_on_text_encodings == self.cond_on_text_encodings and \
-            cond_on_lowres_noise == self.cond_on_lowres_noise and \
+            lowres_noise_cond == self.lowres_noise_cond and \
             channels_out == self.channels_out:
             return self
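
The comparison above sits in the Unet's `cast_model_parameters` method, which hands back the existing module when every requested setting already matches; the old line compared against `self.cond_on_lowres_noise`, which presumably no longer matches the attribute's actual name, `lowres_noise_cond`. A minimal standalone sketch of that early-return pattern, using a hypothetical `MockUnet` rather than the library's real implementation:

```python
class MockUnet:
    def __init__(self, channels = 3, lowres_noise_cond = False, cond_on_text_encodings = False):
        self.channels = channels
        self.lowres_noise_cond = lowres_noise_cond          # the renamed attribute the fix compares against
        self.cond_on_text_encodings = cond_on_text_encodings

    def cast_model_parameters(self, *, channels, lowres_noise_cond, cond_on_text_encodings):
        # if every requested setting already matches, reuse this module as-is
        if channels == self.channels and \
           lowres_noise_cond == self.lowres_noise_cond and \
           cond_on_text_encodings == self.cond_on_text_encodings:
            return self

        # otherwise construct a fresh module with the updated settings
        return MockUnet(
            channels = channels,
            lowres_noise_cond = lowres_noise_cond,
            cond_on_text_encodings = cond_on_text_encodings
        )

unet = MockUnet()
assert unet.cast_model_parameters(channels = 3, lowres_noise_cond = False, cond_on_text_encodings = False) is unet
assert unet.cast_model_parameters(channels = 3, lowres_noise_cond = True, cond_on_text_encodings = False) is not unet
```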

dalle2_pytorch/trainer.py

Lines changed: 6 additions & 2 deletions
@@ -174,7 +174,7 @@ class DiffusionPriorTrainer(nn.Module):
     def __init__(
         self,
         diffusion_prior,
-        accelerator,
+        accelerator = None,
         use_ema = True,
         lr = 3e-4,
         wd = 1e-2,
@@ -186,8 +186,12 @@ def __init__(
     ):
         super().__init__()
         assert isinstance(diffusion_prior, DiffusionPrior)
-        assert isinstance(accelerator, Accelerator)
+
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+        accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
+
+        if not exists(accelerator):
+            accelerator = Accelerator(**accelerator_kwargs)
 
         # assign some helpful member vars
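
With this change, a `DiffusionPriorTrainer` can apparently be constructed without passing an `Accelerator`: when none is given, it builds its own from any keyword arguments prefixed with `accelerator_`. A minimal sketch of the new construction path; the `CLIP`, `DiffusionPriorNetwork`, and `DiffusionPrior` settings are taken from the project README rather than this diff and may differ between versions:

```python
from dalle2_pytorch import CLIP, DiffusionPriorNetwork, DiffusionPrior
from dalle2_pytorch.trainer import DiffusionPriorTrainer

# CLIP and prior settings follow the project README; they are not part of this diff
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
)

# no Accelerator is passed; per the diff, the trainer now builds one itself from
# any 'accelerator_'-prefixed kwargs, e.g. accelerator_cpu = True -> Accelerator(cpu = True)
trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = 3e-4,
    wd = 1e-2,
    accelerator_cpu = True   # kept on CPU so the sketch runs without a GPU
)
```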

dalle2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.2.1'
+__version__ = '1.2.2'
