diff --git a/ZSSGAN/criteria/clip_loss.py b/ZSSGAN/criteria/clip_loss.py index 42d8a77..5cccf75 100644 --- a/ZSSGAN/criteria/clip_loss.py +++ b/ZSSGAN/criteria/clip_loss.py @@ -64,6 +64,9 @@ def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0 preprocess_cnn.transforms[:2] + # to match CLIP input scale assumptions preprocess_cnn.transforms[4:]) # + skip convert PIL to tensor + self.model.requires_grad_(False) + self.model_cnn.requires_grad_(False) + self.texture_loss = torch.nn.MSELoss() def tokenize(self, strings: list):