From a81bf923b962fd19de889f04e5ba2de167ebbf86 Mon Sep 17 00:00:00 2001 From: Yotam Nitzan Date: Wed, 27 Jul 2022 21:39:07 -0700 Subject: [PATCH] Turn off gradients to CLIP's weights --- ZSSGAN/criteria/clip_loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ZSSGAN/criteria/clip_loss.py b/ZSSGAN/criteria/clip_loss.py index 42d8a77..5cccf75 100644 --- a/ZSSGAN/criteria/clip_loss.py +++ b/ZSSGAN/criteria/clip_loss.py @@ -64,6 +64,9 @@ def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0 preprocess_cnn.transforms[:2] + # to match CLIP input scale assumptions preprocess_cnn.transforms[4:]) # + skip convert PIL to tensor + self.model.requires_grad_(False) + self.model_cnn.requires_grad_(False) + self.texture_loss = torch.nn.MSELoss() def tokenize(self, strings: list):