
reconfiguration cae #160


Open · wants to merge 4 commits into base: main
Changes from all commits
3 changes: 3 additions & 0 deletions passl/engine/engine.py
@@ -87,6 +87,9 @@ def worker_init_fn(worker_id):
""" set seed in subproces for dataloader when num_workers > 0"""
np.random.seed(seed + worker_id)
random.seed(seed + worker_id)
else:
worker_init_fn = None
    logger.warning('seed is not set in config; worker_init_fn will be set to None!')

RELATED_FLAGS_SETTING = {}
RELATED_FLAGS_SETTING['FLAGS_cudnn_exhaustive_search'] = 1
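
For context, a minimal sketch of how a seeded worker_init_fn like the one above is wired into paddle.io.DataLoader; the seed value and the toy dataset are illustrative assumptions, not taken from this diff:

import random
import numpy as np
import paddle
from paddle.io import DataLoader, TensorDataset

seed = 42  # assumed: PASSL reads the seed from the config

def worker_init_fn(worker_id):
    """Set seed in subprocess for dataloader when num_workers > 0."""
    np.random.seed(seed + worker_id)
    random.seed(seed + worker_id)

dataset = TensorDataset([paddle.randn([128, 3, 224, 224])])
# Each worker subprocess gets a distinct, reproducible seed.
loader = DataLoader(dataset, batch_size=16, num_workers=4,
                    worker_init_fn=worker_init_fn)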
167 changes: 160 additions & 7 deletions passl/models/cae.py
@@ -16,6 +16,7 @@

import math
import time
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -24,9 +25,14 @@
from .vision_transformer import PatchEmbed, DropPath
from passl.models.base_model import Model
from passl.nn import init
from passl.utils import logger
from scipy import interpolate
from passl.models.utils.pos_embed import interpolate_pos_embed


__all__ = [
'CAEPretrain',
'cae_base_patch16_224',
'cae_small_patch16_224_8k_vocab',
'cae_base_patch16_224_8k_vocab',
'cae_large_patch16_224_8k_vocab',
@@ -1101,12 +1107,15 @@ def __init__(self,
use_mean_pooling=True,
init_scale=0.001,
lin_probe=False,
sin_pos_emb=False,
linear_type='standard',
linear_depth=1,
args=None):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.use_mean_pooling = use_mean_pooling

self.sin_pos_emb = sin_pos_emb
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
@@ -1119,7 +1128,7 @@ def __init__(self,
if use_abs_pos_emb:
self.pos_embed = self.create_parameter(
[1, num_patches + 1, embed_dim])
elif self.sin_pos_emb:
            # sine-cosine positional embeddings are generated by build_2d_sincos_position_embedding
self.pos_embed = self.create_parameter(
[1, num_patches + 1, embed_dim])
@@ -1165,10 +1174,11 @@ def __init__(self,
self.lin_probe = lin_probe
# NOTE: batch norm
self.args = args
self.linear_type = linear_type
self.linear_depth = linear_depth
if lin_probe:
if self.linear_type != 'standard':
if self.linear_type == 'attentive_no_parameter':
no_parameter = True
else:
no_parameter = False
@@ -1186,7 +1196,7 @@ def __init__(self,
norm_layer=norm_layer,
init_values=0,
no_parameter=no_parameter)
for i in range(self.linear_depth)
])

self.query_token = self.create_parameter([1, 1, embed_dim])
@@ -1207,6 +1217,21 @@ def __init__(self,
self.head.weight.set_value(self.head.weight * init_scale)
self.head.bias.set_value(self.head.bias * init_scale)

self.head = paddle.nn.Sequential(
paddle.nn.BatchNorm1D(
self.head.weight.shape[0],
epsilon=1e-6,
weight_attr=False,
bias_attr=False),
self.head)

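        # Linear probing: freeze every parameter first, then re-enable
        # gradients only for the Linear classifier at self.head[1]; the
        # BatchNorm1D prepended above is affine-free (weight_attr=False,
        # bias_attr=False), so it adds no trainable parameters.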
for _, p in self.named_parameters():
p.stop_gradient = True

for _, p in self.head[1].named_parameters():
p.stop_gradient = False


def build_2d_sincos_position_embedding(self,
embed_dim=768,
temperature=10000.):
@@ -1279,6 +1304,7 @@ def forward_features(self, x, is_train=True):
[batch_size, -1,
-1]) # stole cls_tokens impl from Phil Wang, thanks
x = paddle.concat((cls_tokens, x), axis=1)

if self.pos_embed is not None:
if self.use_abs_pos_emb:
x = x + self.pos_embed.expand(
@@ -1320,10 +1346,130 @@ def forward_features(self, x, is_train=True):
return query_tokens[:, 0, :]

    def forward(self, x, is_train=True):
        print('input: ', x.detach())
        x = self.forward_features(x, is_train)
        print('forward_features: ', x.detach().sum().cpu().numpy())
        # Debug: apply the head layer by layer, printing activation and
        # parameter statistics for each sub-layer.
        for na, layer in self.head._sub_layers.items():
            print(x.detach().mean().cpu().numpy())
            x = layer(x)
            for n, p in layer.named_parameters():
                print(n, p.detach().sum().cpu().numpy())
            print(na, ':', x.detach().sum().cpu().numpy())
        return x

def load_pretrained(self, path, rank=0, finetune=False):
checkpoint = paddle.load(path)

logger.info("Load pre-trained checkpoint from: %s" % path)
checkpoint_model = checkpoint['model']
state_dict = self.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and list(checkpoint_model[
k].shape) != list(state_dict[k].shape):
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]

        for key in list(checkpoint_model.keys()):
            # Drop teacher/decoder weights first, so a key containing both
            # 'teacher' and 'encoder.' is not popped twice (KeyError).
            if 'teacher' in key or 'decoder' in key:
                checkpoint_model.pop(key)
            elif 'encoder.' in key:
                new_key = key.replace('encoder.', '')
                checkpoint_model[new_key] = checkpoint_model.pop(key)

if self.rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
print(
"Expand the shared relative position embedding to each transformer block. "
)
num_layers = self.get_num_layers()
rel_pos_bias = checkpoint_model[
"rel_pos_bias.relative_position_bias_table"]
for i in range(num_layers):
checkpoint_model["blocks.%d.attn.relative_position_bias_table"
% i] = rel_pos_bias.clone()

checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")

all_keys = list(checkpoint_model.keys())

for key in all_keys:
if "relative_position_index" in key:
checkpoint_model.pop(key)

if "relative_position_bias_table" in key and self.rel_pos_bias:
rel_pos_bias = checkpoint_model[key]
                src_num_pos, num_attn_heads = rel_pos_bias.shape
                dst_num_pos, _ = self.state_dict()[key].shape
dst_patch_shape = self.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - (
dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
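                # The bias table holds one entry per relative offset on a
                # (2H - 1) x (2W - 1) grid, plus num_extra_tokens special
                # entries (e.g. cls-token biases); only the offset grid is
                # interpolated below, the extra tokens are carried over as-is.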
if src_size != dst_size:
print("Position interpolate for %s from %dx%d to %dx%d" %
(key, src_size, src_size, dst_size, dst_size))
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)

left, right = 1.01, 1.5
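                    # Bisection: find a ratio q such that the geometric series
                    # 1 + q + q**2 + ... with src_size // 2 terms reaches
                    # dst_size // 2, producing non-uniformly spaced source
                    # coordinates (denser near the center) for the cubic
                    # interpolation below.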
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q

dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q ** (i + 1)

r_ids = [-_ for _ in reversed(dis)]

x = r_ids + [0] + dis
y = r_ids + [0] + dis

t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)

print("Original positions = %s" % str(x))
print("Target positions = %s" % str(dx))

all_rel_pos_bias = []

                    # NOTE: scipy.interpolate.interp2d is deprecated in recent
                    # SciPy releases; RegularGridInterpolator is its suggested
                    # replacement.
                    for i in range(num_attn_heads):
                        z = rel_pos_bias[:, i].reshape(
                            [src_size, src_size]).astype('float32').numpy()
                        f = interpolate.interp2d(x, y, z, kind='cubic')
                        all_rel_pos_bias.append(
                            paddle.to_tensor(f(dx, dy)).reshape([-1, 1]))

                    rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)

                    new_rel_pos_bias = paddle.concat(
                        (rel_pos_bias, extra_tokens), axis=0)
checkpoint_model[key] = new_rel_pos_bias

# interpolate position embedding
interpolate_pos_embed(self, checkpoint_model)

# load pre-trained model
self.set_state_dict(checkpoint_model)
init.trunc_normal_(self.head[1].weight, std=0.01)


def save(self, path, local_rank=0, rank=0):
paddle.save(self.state_dict(), path + ".pdparams")


def cae_small_patch16_224_8k_vocab(**kwargs):
model = CAEPretrain(
@@ -1392,6 +1538,13 @@ def cae_base_patch16_224(**kwargs):
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
sin_pos_emb=True,
lin_probe=True,
use_mean_pooling=False,
use_rel_pos_bias=False,
use_abs_pos_emb=False,
init_scale=0.001,
init_values=0.1,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6, weight_attr=True, bias_attr=True),
**kwargs)
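
For reference, a hedged usage sketch of the linear-probe entry point; the import path matches this repo, but the checkpoint path and num_classes are illustrative assumptions:

import paddle
from passl.models.cae import cae_base_patch16_224

model = cae_base_patch16_224(num_classes=1000)
# load_pretrained expects a dict with a 'model' key, as produced by CAE
# pre-training; the path here is hypothetical.
model.load_pretrained('output/cae_base_pretrain.pdparams')
model.save('output/cae_base_lin_probe')  # writes output/cae_base_lin_probe.pdparams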
29 changes: 14 additions & 15 deletions passl/optimizer/__init__.py
@@ -144,26 +144,25 @@ def build_optimizer(config, lr_scheduler, model, epochs, step_each_epoch, lr_dec
if hasattr(model, 'param_group_fn'):
# param groups are defined by model
model_group_cfg = config.pop('param_group_fn', {})
param_group_map = model.param_group_fn(**model_group_cfg)
else:
param_groups_cfg = config.get('param_groups', None)
if param_groups_cfg and len(param_groups_cfg) > 0:
param_groups_cfg = build_group_lr_scheduler(param_groups_cfg, epochs, step_each_epoch, lr_decay_unit)
param_group_map = group_params(model, param_groups_cfg)
if isinstance(layer_decay, float):
param_group_map = param_group_layer_decay(model,
layer_decay,
weight_decay=weight_decay,
param_groups_map=param_group_map,
no_weight_decay_list=no_weight_decay_name,
)
elif len(no_weight_decay_name) > 0:
param_group_map = param_group_weight_decay(model,
weight_decay=weight_decay,
param_groups_map=param_group_map,
no_weight_decay_list=no_weight_decay_name,
)

for key in param_group_map:
param_group_map[key]['params'] = [p for (n, p) in param_group_map[key]['params']]
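
For context, a minimal sketch (entirely hypothetical, not part of PASSL) of a model-side param_group_fn that this refactor expects: it receives the options under the optimizer config's param_group_fn key and returns a dict whose 'params' entries hold (name, param) tuples, matching the unpacking loop above:

import paddle

class MyModel(paddle.nn.Layer):
    def param_group_fn(self, weight_decay=0.05, **kwargs):
        # Split trainable parameters into weight-decay / no-weight-decay groups.
        decay, no_decay = [], []
        for name, p in self.named_parameters():
            if p.stop_gradient:
                continue  # frozen parameters stay out of the optimizer
            (no_decay if name.endswith('.bias') else decay).append((name, p))
        return {
            'decay': {'params': decay, 'weight_decay': weight_decay},
            'no_decay': {'params': no_decay, 'weight_decay': 0.0},
        }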