From 41fabf29220d0469af9c2681068e2ef99caa0085 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 19 Oct 2022 09:26:00 -0700 Subject: [PATCH] fix a dtype conversion issue for the diffusion timesteps in the diffusion prior, thanks to @JiaHeng-DLUT --- dalle2_pytorch/dalle2_pytorch.py | 10 ++++++++++ dalle2_pytorch/version.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 2c179089..6491e2c9 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -100,6 +100,9 @@ def inner(model, *args, **kwargs): return out return inner +def is_float_dtype(dtype): + return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)]) + def is_list_str(x): if not isinstance(x, (list, tuple)): return False @@ -968,6 +971,8 @@ def __init__( Rearrange('b (n d) -> b n d', n = num_text_embeds) ) + self.continuous_embedded_time = not exists(num_timesteps) + self.to_time_embeds = nn.Sequential( nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP Rearrange('b (n d) -> b n d', n = num_time_embeds) @@ -1095,6 +1100,9 @@ def forward( # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different) # but let's just do it right + if self.continuous_embedded_time: + diffusion_timesteps = diffusion_timesteps.type(dtype) + time_embed = self.to_time_embeds(diffusion_timesteps) learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch) @@ -1538,6 +1546,8 @@ def __init__(self, dim): def forward(self, x): dtype, device = x.dtype, x.device + assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type' + half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 8047ba3f..192e7e0f 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.7' +__version__ = '1.10.8'