just take care of the logic for AdamW and transformers
lucidrains committed Apr 29, 2022
1 parent 39d3659 commit 846162e
Showing 3 changed files with 114 additions and 1 deletion.
84 changes: 84 additions & 0 deletions dalle2_pytorch/openai_clip.py
@@ -0,0 +1,84 @@
import torch
from PIL import Image

from dalle2_pytorch.dalle2_pytorch import BaseClipAdapter
import torchvision.transforms as T

def find_layer(model, layer):
    # look up a submodule of `model` by its dotted module path
    modules = dict([*model.named_modules()])
    return modules.get(layer, None)

def hook(_, input, output):
    # scratch debugging hook that simply prints the shape of the intercepted output
    print(output.shape)

# module-level smoke-test inputs (require openai clip installed and a GPU)
import clip
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).cuda()
image = torch.randn(1, 3, 224, 224).cuda()


class OpenAIClipAdapter(BaseClipAdapter):
    def __init__(self, name = 'ViT-B/32'):
        try:
            import clip
        except ImportError:
            print('you must install openai clip in order to use this adapter - `pip install git+https://github.com/openai/CLIP.git` - more instructions at https://github.com/openai/CLIP#usage')
            raise

        openai_clip, _ = clip.load(name)
        super().__init__(openai_clip)

        # hook the final text layernorm so per-token encodings can be captured alongside the pooled text embedding
        text_attention_final = self.find_layer('ln_final')
        self.handle = text_attention_final.register_forward_hook(self._hook)
        self.clip_normalize = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.cleared = False

    def find_layer(self, layer):
        modules = dict([*self.clip.named_modules()])
        return modules.get(layer, None)

    def clear(self):
        if self.cleared:
            return

        self.cleared = True
        self.handle.remove()

    def _hook(self, _, inputs, outputs):
        # forward hook on ln_final - stashes the token-level text encodings for embed_text
        self.text_encodings = outputs

    @property
    def dim_latent(self):
        return 512

    @property
    def image_size(self):
        return self.clip.visual.input_resolution

    @property
    def image_channels(self):
        return 3

    @torch.no_grad()
    def embed_text(self, text):
        assert not self.cleared

        text_embed = self.clip.encode_text(text)
        text_encodings = self.text_encodings  # captured by the forward hook during encode_text
        del self.text_encodings
        return text_embed, text_encodings

    @torch.no_grad()
    def embed_image(self, image):
        assert not self.cleared

        image = self.clip_normalize(image)
        image_embed = self.clip.encode_image(image)
        return image_embed, None

# smoke test the adapter with the scratch inputs defined above
clip_adapter = OpenAIClipAdapter().cuda()

# print(model)
with torch.no_grad():
    image_features, _ = clip_adapter.embed_image(image)
    text_features, text_encodings = clip_adapter.embed_text(text)
    print(text_features.shape, image_features.shape)
    print(text_encodings.shape)
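
The adapter obtains its per-token text encodings as a side effect: a standard PyTorch forward hook is registered on CLIP's final text layer norm (`ln_final`), and whatever that layer outputs during `encode_text` is stashed on the adapter. A minimal, self-contained sketch of the same hook pattern, using a toy module instead of CLIP (the `Toy` class and `captured` dict here are illustrative, not part of this commit):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_final = nn.LayerNorm(8)
        self.to_pooled = nn.Linear(8, 4)

    def forward(self, x):
        return self.to_pooled(self.ln_final(x))

toy = Toy()
captured = {}

def capture(_module, _inputs, outputs):
    # stash the intermediate activation, mirroring OpenAIClipAdapter._hook
    captured['ln_final'] = outputs

handle = toy.ln_final.register_forward_hook(capture)

with torch.no_grad():
    pooled = toy(torch.randn(2, 6, 8))

print(pooled.shape)               # torch.Size([2, 6, 4])
print(captured['ln_final'].shape) # torch.Size([2, 6, 8])

handle.remove()  # detach the hook, analogous to what clear() intends
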
29 changes: 29 additions & 0 deletions dalle2_pytorch/optimizer.py
@@ -0,0 +1,29 @@
from torch.optim import AdamW, Adam

def separate_weight_decayable_params(params):
    # parameters with fewer than 2 dims (biases, norm scales and shifts) should not be weight decayed
    no_wd_params = set([param for param in params if param.ndim < 2])
    wd_params = set(params) - no_wd_params
    return wd_params, no_wd_params

def get_optimizer(
    params,
    lr = 3e-4,
    wd = 1e-2,
    betas = (0.9, 0.999),
    filter_by_requires_grad = False
):
    if filter_by_requires_grad:
        params = list(filter(lambda t: t.requires_grad, params))

    if wd == 0:
        return Adam(params, lr = lr, betas = betas)

    params = set(params)
    wd_params, no_wd_params = separate_weight_decayable_params(params)

    param_groups = [
        {'params': list(wd_params)},
        {'params': list(no_wd_params), 'weight_decay': 0},
    ]

    return AdamW(param_groups, lr = lr, weight_decay = wd, betas = betas)
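
For reference, a hedged usage sketch of `get_optimizer` (the toy model below is illustrative, not part of the commit): parameters with fewer than two dimensions, i.e. biases and LayerNorm weights, land in the parameter group whose `weight_decay` is forced to 0, while the Linear weight matrices receive the requested decay.

import torch.nn as nn
from dalle2_pytorch.optimizer import get_optimizer

# toy transformer-style block: Linear weight matrices get weight decay,
# biases and LayerNorm parameters (ndim < 2) do not
model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.LayerNorm(64), nn.Linear(64, 16))

opt = get_optimizer(model.parameters(), lr = 3e-4, wd = 1e-2)

for group in opt.param_groups:
    num_params = sum(p.numel() for p in group['params'])
    print(group['weight_decay'], num_params)
# expected: 0.01 for the two weight matrices, 0 for biases and LayerNorm params
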
2 changes: 1 addition & 1 deletion setup.py
@@ -10,7 +10,7 @@
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
- version = '0.0.67',
+ version = '0.0.70',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
