From c59ebf49eb8f0e01d773c1ffd7d329079c6a384a Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 19 Feb 2024 06:40:51 -0800 Subject: [PATCH] address https://github.com/lucidrains/denoising-diffusion-pytorch/issues/293 --- denoising_diffusion_pytorch/karras_unet.py | 17 +++++++++++------ denoising_diffusion_pytorch/version.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py index e73c663a4..69b9d26e4 100644 --- a/denoising_diffusion_pytorch/karras_unet.py +++ b/denoising_diffusion_pytorch/karras_unet.py @@ -124,6 +124,12 @@ def forward(self, x): # forced weight normed conv2d and linear # algorithm 1 in paper +def normalize_weight(weight, eps = 1e-4): + weight, ps = pack_one(weight, 'o *') + normed_weight = l2norm(weight, eps = eps) + normed_weight = normed_weight * sqrt(weight.numel() / weight.shape[0]) + return unpack_one(normed_weight, ps, 'o *') + class Conv2d(Module): def __init__( self, @@ -142,14 +148,13 @@ def __init__( self.concat_ones_to_input = concat_ones_to_input def forward(self, x): + if self.training: with torch.no_grad(): - weight, ps = pack_one(self.weight, 'o *') - normed_weight = l2norm(weight, eps = self.eps) - normed_weight = unpack_one(normed_weight, ps, 'o *') + normed_weight = normalize_weight(self.weight, eps = self.eps) self.weight.copy_(normed_weight) - weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in) + weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in) if self.concat_ones_to_input: x = F.pad(x, (0, 0, 0, 0, 1, 0), value = 1.) @@ -167,10 +172,10 @@ def __init__(self, dim_in, dim_out, eps = 1e-4): def forward(self, x): if self.training: with torch.no_grad(): - normed_weight = l2norm(self.weight, eps = self.eps) + normed_weight = normalize_weight(self.weight, eps = self.eps) self.weight.copy_(normed_weight) - weight = l2norm(self.weight, eps = self.eps) / sqrt(self.fan_in) + weight = normalize_weight(self.weight, eps = self.eps) / sqrt(self.fan_in) return F.linear(x, weight) # mp fourier embeds diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py index 075f2f55a..461bcf588 100644 --- a/denoising_diffusion_pytorch/version.py +++ b/denoising_diffusion_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.10' +__version__ = '1.10.12'