Skip to content

Commit

Permalink
force first unet in the cascade to be conditioned on image embeds
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 29, 2022
1 parent cb26187 commit aa90021
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,13 +1066,14 @@ def cast_model_parameters(
self,
*,
lowres_cond,
channels
channels,
cond_on_image_embeds
):
if lowres_cond == self.lowres_cond and channels == self.channels:
if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds:
return self

updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds}
return self.__class__(**{**self._locals, **updated_kwargs})

def forward_with_cond_scale(
self,
Expand Down Expand Up @@ -1279,6 +1280,7 @@ def __init__(

one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
cond_on_image_embeds = is_first,
channels = unet_channels
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.63',
version = '0.0.64',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
Expand Down

0 comments on commit aa90021

Please sign in to comment.