diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index 78e9b4dc25..a96c2d3a61 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -75,6 +75,7 @@ from .model_zoo.convnext import ConvNeXt_tiny, ConvNeXt_small, ConvNeXt_base_224, ConvNeXt_base_384, ConvNeXt_large_224, ConvNeXt_large_384
 from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
 from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
+from .model_zoo.ibot import IBOT, IBOT_ViT_small_patch16_224, IBOT_ViT_base_patch16_224, IBOT_ViT_large_patch16_224
 
 from .variant_models.resnet_variant import ResNet50_last_stage_stride1
 from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
diff --git a/ppcls/arch/backbone/model_zoo/ibot.py b/ppcls/arch/backbone/model_zoo/ibot.py
new file mode 100644
index 0000000000..c8df8574b0
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/ibot.py
@@ -0,0 +1,597 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Code was based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# reference: https://arxiv.org/abs/2010.11929
+
+import math
+import os
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import Constant, Normal
+
+from .vision_transformer import VisionTransformer, Identity, trunc_normal_
+
+MODEL_URLS = {
+    "IBOT_ViT_small_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_small_patch16_224_pretrained.pdparams",
+    "IBOT_ViT_base_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_224_pretrained.pdparams",
+    "IBOT_ViT_large_patch16_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch16_384_pretrained.pdparams",
+    "IBOT_Swin_tiny_patch7_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_base_patch32_384_pretrained.pdparams",
+    "IBOT_Swin_tiny_patch14_224": "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ViT_large_patch16_224_pretrained.pdparams",
+}
+
+# only the ViT variants are implemented in this file, so the Swin keys in
+# MODEL_URLS are not exported
+__all__ = [
+    "IBOT", "IBOT_ViT_small_patch16_224", "IBOT_ViT_base_patch16_224",
+    "IBOT_ViT_large_patch16_224"
+]
+
+# the initializer must be instantiated before use; std=0.01 follows the
+# linear-probing init of the reference DINO/iBOT code
+normal_ = Normal(mean=0.0, std=0.01)
+zeros_ = Constant(value=0.0)
+ones_ = Constant(value=1.0)
+
+
+def to_2tuple(x):
+    return tuple([x] * 2)
+
+
+class IBOT_PatchEmbed(nn.Layer):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * \
+            (img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2D(
+            in_chans,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # assert H == self.img_size[0] and W == self.img_size[1], \
+        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+
+        # the projected feature map is kept as NCHW here; flattening to tokens
+        # happens in the caller, after the optional mask-image-modeling step
+        # x = self.proj(x).flatten(2).transpose((0, 2, 1))
+        x = self.proj(x)
+        return x
+
+
+class MultiCropWrapper(nn.Layer):
+    """
+    Perform forward pass separately on each resolution input.
+    The inputs corresponding to a single resolution are clubbed and a single
+    forward pass is run on the same-resolution inputs. Hence we do several
+    forward passes = number of different resolutions used. We then
+    concatenate all the output features and run the head forward on these
+    concatenated features.
+    """
+
+    def __init__(self, backbone, head=None):
+        super(MultiCropWrapper, self).__init__()
+        # disable layers dedicated to ImageNet labels classification
+        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
+        self.backbone = backbone
+        if head is None:
+            self.head = nn.Identity()
+        else:
+            self.head = head
+
+    def forward(self, x, mask=None, return_backbone_feat=False, **kwargs):
+        # convert to list
+        if not isinstance(x, list):
+            x = [x]
+            mask = [mask] if mask is not None else None
+        idx_crops = paddle.cumsum(
+            paddle.unique_consecutive(
+                paddle.to_tensor([inp.shape[-1] for inp in x]),
+                return_counts=True,
+            )[1],
+            0,
+        )
+
+        start_idx, output = 0, paddle.empty((0,))
+        for end_idx in idx_crops:
+            inp_x = paddle.concat(x[start_idx:end_idx])
+
+            if mask is not None:
+                inp_m = paddle.concat(mask[start_idx:end_idx])
+                kwargs.update(dict(mask=inp_m))
+
+            _out = self.backbone(inp_x, **kwargs)
+            if start_idx == 0:
+                output = _out
+            else:
+                output = paddle.concat((output, _out))
+            start_idx = end_idx
+
+        # Run the head forward on the concatenated features.
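+        # the head sees the whole multi-crop batch in one pass; per-crop
+        # outputs can be recovered afterwards by chunking along the batch axis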
+        output_ = self.head(output)
+        if return_backbone_feat:
+            return output, output_
+        return output_
+
+
+class DINOHead(nn.Layer):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 norm=None,
+                 act_layer=nn.GELU,
+                 last_norm=None,
+                 nlayers=3,
+                 hidden_dim=2048,
+                 bottleneck_dim=256,
+                 norm_last_layer=True,
+                 epsilon=1e-5,
+                 **kwargs):
+        super().__init__()
+        if norm is not None:
+            self.norm = eval(norm)(hidden_dim, epsilon=epsilon)
+        if last_norm is not None:
+            self.last_norm = eval(last_norm)(out_dim, epsilon=epsilon)
+        else:
+            self.last_norm = None
+        if act_layer is not None:
+            self.act = act_layer()
+
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            if bottleneck_dim > 0:
+                self.mlp = nn.Linear(in_dim, bottleneck_dim)
+            else:
+                self.mlp = nn.Linear(in_dim, out_dim)
+        else:
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if norm is not None:
+                # `norm` is a string such as "nn.LayerNorm", so a fresh layer
+                # has to be instantiated for every position in the MLP;
+                # appending the string itself would break nn.Sequential
+                layers.append(eval(norm)(hidden_dim, epsilon=epsilon))
+            layers.append(self.act)
+
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if norm is not None:
+                    layers.append(eval(norm)(hidden_dim, epsilon=epsilon))
+                layers.append(self.act)
+            if bottleneck_dim > 0:
+                layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            else:
+                layers.append(nn.Linear(hidden_dim, out_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+
+        if bottleneck_dim > 0:
+            self.last_layer = nn.utils.weight_norm(
+                nn.Linear(bottleneck_dim, out_dim, bias_attr=False), dim=1)
+            ones_(self.last_layer.weight_g)
+            if norm_last_layer:
+                # as in the reference DINO code, keep the norm of the last
+                # layer fixed by freezing weight_g
+                self.last_layer.weight_g.stop_gradient = True
+        else:
+            self.last_layer = None
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight)
+        if isinstance(m, nn.Linear) and m.bias is not None:
+            zeros_(m.bias)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        if self.last_layer is not None:
+            x = nn.functional.normalize(x, axis=-1, p=2)
+            x = self.last_layer(x)
+        if self.last_norm is not None:
+            x = self.last_norm(x)
+        return x
+
+
+class IBOTHead(DINOHead):
+    def __init__(self,
+                 *args,
+                 patch_out_dim=8192,
+                 norm=None,
+                 act_layer=nn.GELU,
+                 last_norm=None,
+                 nlayers=3,
+                 epsilon=1e-5,
+                 hidden_dim=2048,
+                 bottleneck_dim=256,
+                 norm_last_layer=True,
+                 shared_head=False,
+                 **kwargs):
+        super(IBOTHead, self).__init__(
+            *args,
+            norm=norm,
+            act_layer=act_layer,
+            last_norm=last_norm,
+            nlayers=nlayers,
+            hidden_dim=hidden_dim,
+            bottleneck_dim=bottleneck_dim,
+            norm_last_layer=norm_last_layer,
+            **kwargs)
+        if not shared_head:
+            if bottleneck_dim > 0:
+                self.last_layer2 = nn.utils.weight_norm(
+                    nn.Linear(bottleneck_dim, patch_out_dim, bias_attr=False),
+                    dim=1)
+                ones_(self.last_layer2.weight_g)
+                if norm_last_layer:
+                    # freeze weight_g of the patch head as well
+                    self.last_layer2.weight_g.stop_gradient = True
+            else:
+                self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)
+                self.last_layer2 = None
+
+            if last_norm is not None:
+                self.last_norm2 = eval(last_norm)(patch_out_dim, epsilon=epsilon)
+        else:
+            if bottleneck_dim > 0:
+                self.last_layer2 = self.last_layer
+            else:
+                self.mlp2 = self.mlp[-1]
+                self.last_layer2 = None
+            if last_norm is not None:
+                self.last_norm2 = self.last_norm
+
+    def forward(self, x):
+        if len(x.shape) == 2:
+            return super(IBOTHead, self).forward(x)
+
+        if self.last_layer is not None:
+            x = self.mlp(x)
+            x = nn.functional.normalize(x, axis=-1, p=2)
+            x1 = self.last_layer(x[:, 0])
+            x2 = self.last_layer2(x[:, 1:])
+        else:
+            x = self.mlp[:-1](x)
+            x1 = self.mlp[-1](x[:, 0])
+            x2 = self.mlp2(x[:, 1:])
+
+        if self.last_norm is not None:
+            x1 = self.last_norm(x1)
+            x2 = self.last_norm2(x2)
+
+        return x1, x2
+
+
+class
IBOTVisionTransformer(VisionTransformer): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + class_num=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer="nn.LayerNorm", + epsilon=1e-6, + return_all_tokens=False, + masked_im_modeling=False, + **kwargs + ): + super(IBOTVisionTransformer, self).__init__( + img_size, + patch_size, + in_chans, + class_num, + embed_dim, + depth, + num_heads, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + epsilon, + **kwargs + ) + self.return_all_tokens = return_all_tokens + self.masked_im_modeling = masked_im_modeling + self.img_size = img_size + self.patch_size = patch_size + self.patch_embed = IBOT_PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + + if self.masked_im_modeling: + self.masked_embed = self.create_parameter( + shape=[1, embed_dim], default_initializer=zeros_ + ) + # trunc_normal_(self.masked_embed) + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size[0] + h0 = h // self.patch_embed.patch_size[1] + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape((1, int(math.sqrt(N)), int(math.sqrt(N)), dim)).transpose((0, 3, 1, 2)), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.transpose((0, 2, 3, 1)).reshape((1, -1, dim)) + return paddle.concat((class_pos_embed.unsqueeze(0), patch_pos_embed), axis=1) + + + def forward_features(self, x, mask=None, return_all_tokens=None): + B, nc, w, h = x.shape + x= self.patch_embed(x) + # x = paddle.transpose(x, perm=[0, 2, 1]) + # C,N,HW = x.shape + # H,W = int(self.img_size/self.patch_size),int(self.img_size/self.patch_size) + # x = x.reshape([C,N,H,W]) + # mask image modeling + if self.masked_im_modeling: + assert mask is not None + x = self.mask_model(x, mask) + x = x.flatten(2).transpose(perm=[0, 2, 1]) + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype) + x = paddle.concat((cls_tokens, x), axis=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + # if self.fc_norm is not None: + # x[:, 0] = self.fc_norm(x[:, 1:, :].mean(1)) + + return_all_tokens = ( + self.return_all_tokens if return_all_tokens is None else return_all_tokens + ) + + if return_all_tokens: + return x + + return x[:, 0] + + def forward(self, x, mask=None): + x = self.forward_features(x, mask, return_all_tokens=self.return_all_tokens) + # x = self.head(x) + + return x + + def mask_model(self, x, mask): + x = paddle.transpose(x, perm=[0, 2, 3, 1]) + x = paddle.where(mask.unsqueeze(-1), paddle.cast(self.masked_embed, x.dtype), x) + x = paddle.transpose(x, perm=[0, 3, 1, 2]) + return x + + def get_intermediate_layers(self, x, n=1,mask=None): + + B, nc, w, h = 
x.shape + x = self.patch_embed(x) + # mask image modeling + if self.masked_im_modeling: + assert mask is not None + x = self.mask_model(x, mask) + x = x.flatten(2).transpose(perm=[0, 2, 1]) + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype) + x = paddle.concat((cls_tokens, x), axis=1) + x = x + self.interpolate_pos_encoding(x, w, h) + x = self.pos_drop(x) + + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + +def IBOT_ViT_small_patch16_224(patch_size=16, **kwargs): + model = IBOTVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + qkv_bias=True, + **kwargs + ) + return model + + +def IBOT_ViT_base_patch16_224(patch_size=16, **kwargs): + model = IBOTVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + **kwargs + ) + + return model + + +def IBOT_ViT_large_patch16_224(patch_size=16, **kwargs): + model = IBOTVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + **kwargs + ) + + return model + + +class LinearClassifier(nn.Layer): + """Linear layer to train on top of frozen features""" + def __init__(self, dim, num_labels=1000): + super(LinearClassifier, self).__init__() + self.num_labels = num_labels + self.linear = nn.Linear(dim, num_labels) + normal_(self.linear.weight) + zeros_(self.linear.bias) + + def forward(self, x): + # flatten + x = x.reshape((x.shape[0], -1)) + + # linear layer + return self.linear(x) + + +class IBOT(nn.Layer): + def __init__(self, **arch_config): + super(IBOT, self).__init__() + assert arch_config['arch'] in ['ViT_small', 'ViT_base','ViT_large'], f"arch can be only ['ViT_small', 'ViT_base','ViT_large']" + model_name = "IBOT_" + arch_config['arch'] + "_patch" + str(arch_config['patch_size']) + "_224" + model_name = eval(model_name) + self.train_stage = arch_config['mode'] + self.arch_config = arch_config + if arch_config['mode'] == 'pretrain': + student = model_name( + patch_size=arch_config["patch_size"], + drop_path_rate=arch_config["drop_path"], + return_all_tokens=True, + masked_im_modeling=arch_config["use_masked_im_modeling"] + ), + student = student[0] + teacher = model_name( + patch_size=arch_config["patch_size"], + return_all_tokens=True, + ), + teacher = teacher[0] + embed_dim = student.embed_dim + # multi-crop wrapper handles forward with inputs of different resolutions + self.student = MultiCropWrapper( + student, + IBOTHead( + embed_dim, + arch_config["out_dim"], + patch_out_dim=arch_config["patch_out_dim"], + norm=arch_config["norm_in_head"], + act=arch_config["act_in_head"], + norm_last_layer=arch_config["norm_last_layer"], + shared_head=arch_config["shared_head"], + ) + ) + self.teacher = MultiCropWrapper( + teacher, + IBOTHead( + embed_dim, + arch_config["out_dim"], + patch_out_dim=arch_config["patch_out_dim"], + norm=arch_config["norm_in_head"], + act=arch_config["act_in_head"], + shared_head=arch_config["shared_head_teacher"], + ) + ) + + # vit_s8 and vit_s16 are batch norm free models. 
here, we don't check bn
+            self.teacher = paddle.DataParallel(self.teacher)
+            self.teacher_without_ddp = self.teacher._layers
+            self.student = paddle.DataParallel(self.student)
+
+            # teacher and student start with the same weights
+            self.teacher_without_ddp.load_dict(self.student.state_dict())
+
+            # there is no backpropagation through the teacher, so no need for gradients
+            for p in self.teacher.parameters():
+                p.stop_gradient = True
+
+        else:
+            self.model = model_name(
+                patch_size=arch_config['patch_size'],
+                num_classes=0,
+                use_mean_pooling=arch_config["avgpool_patchtokens"] == 1)
+            feat_dim = self.model.embed_dim * (
+                arch_config['n_last_blocks'] * int(arch_config["avgpool_patchtokens"] != 1) +
+                int(arch_config["avgpool_patchtokens"] > 0))
+            self.model.eval()
+            for p in self.model.parameters():
+                p.stop_gradient = True
+
+            self.linear_clf = paddle.DataParallel(
+                LinearClassifier(feat_dim, arch_config['num_labels']))
+
+            if os.path.isfile(arch_config['pretrained_weights']):
+                state_dict = paddle.load(arch_config['pretrained_weights'])[arch_config['checkpoint_key']]
+                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+                new_state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+                self.model.set_state_dict(new_state_dict)
+
+            self.n_last_blocks = arch_config['n_last_blocks']
+            self.avgpool = arch_config['avgpool_patchtokens']
+
+    def forward(self, images, masks):
+        if self.train_stage == 'pretrain':
+            teacher_output = self.teacher(
+                images[:self.arch_config["global_crops_number"]])  # only the 2 global views pass through the teacher
+            student_output = self.student(
+                images[:self.arch_config["global_crops_number"]],
+                mask=masks[:self.arch_config["global_crops_number"]])  # all views pass through the student
+
+            self.student.sublayers()[0].backbone.masked_im_modeling = False
+            student_local_cls = self.student(images[self.arch_config["global_crops_number"]:])[0] if len(
+                images) > self.arch_config["global_crops_number"] else None
+            self.student.sublayers()[0].backbone.masked_im_modeling = self.arch_config["use_masked_im_modeling"]
+
+            return teacher_output, student_output, student_local_cls
+
+        else:  # finetune
+            self.linear_clf.train()
+
+            # forward
+            with paddle.no_grad():
+                intermediate_output = self.model.get_intermediate_layers(images, self.n_last_blocks)
+                if self.avgpool == 0:
+                    # norm(x[:, 0])
+                    output = [x[:, 0] for x in intermediate_output]
+                elif self.avgpool == 1:
+                    # x[:, 1:].mean(1)
+                    output = [paddle.mean(intermediate_output[-1][:, 1:], axis=1)]
+                elif self.avgpool == 2:
+                    # norm(x[:, 0]) + x[:, 1:].mean(1)
+                    output = [x[:, 0] for x in intermediate_output] + [
+                        paddle.mean(intermediate_output[-1][:, 1:], axis=1)
+                    ]
+                else:
+                    assert False, "Unknown avgpool type {}".format(self.avgpool)
+
+                output = paddle.concat(output, axis=-1)
+
+            return self.linear_clf.forward(output)
\ No newline at end of file
diff --git a/ppcls/arch/backbone/model_zoo/vision_transformer.py b/ppcls/arch/backbone/model_zoo/vision_transformer.py
index fbec1fcb48..b7f0ef6143 100644
--- a/ppcls/arch/backbone/model_zoo/vision_transformer.py
+++ b/ppcls/arch/backbone/model_zoo/vision_transformer.py
@@ -299,7 +299,7 @@ def _init_weights(self, m):
             ones_(m.weight)
 
     def forward_features(self, x):
-        # B = x.shape[0]
+        B = paddle.shape(x)[0]
         x = self.patch_embed(x)
         cls_tokens = self.cls_token.expand((B, -1, -1))
diff --git a/ppcls/configs/ImageNet/IBOT/IBOT_ViT_small_p16_pretrain.yaml b/ppcls/configs/ImageNet/IBOT/IBOT_ViT_small_p16_pretrain.yaml
new file mode 100644
index
0000000000..93e5167b29
--- /dev/null
+++ b/ppcls/configs/ImageNet/IBOT/IBOT_ViT_small_p16_pretrain.yaml
@@ -0,0 +1,170 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  train_mode: ibot
+
+  warmup_teacher_temp: 0.04
+  teacher_temp: 0.07
+  warmup_teacher_temp_epochs: 30
+  use_fp16: True
+  weight_decay: 0.04
+  weight_decay_end: 0.4
+  epochs: 800
+  freeze_last_layer: 1
+  lr: 0.001
+  warmup_epochs: 10
+  min_lr: 1e-06
+  batch_size: 32
+#  num_workers: 10
+  global_crops_scale: [ 0.25, 1.0 ]
+  local_crops_number: 10
+  local_crops_scale: [ 0.05, 0.25 ]
+  pred_ratio: [ 0, 0.3 ]
+  pred_ratio_var: [ 0, 0.2 ]
+
+  seed: 0
+  ngpus: 8
+  nodes: 2
+  device: gpu
+#  optimizer: adamw
+  momentum_teacher: 0.996
+  output_dir: ./output_ibot_vit_small_p16/
+
+  eval_during_train: False
+  eval_interval: 1
+  print_batch_step: 20
+  use_visualdl: True
+  save_interval: 1
+
+AMP:
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  # O1: mixed fp16
+  level: O1
+
+# model architecture
+Arch:
+  name: IBOT
+  mode: pretrain
+  arch: ViT_small
+  patch_size: 16
+  out_dim: 8192
+  patch_out_dim: 8192
+  norm_last_layer: False
+  shared_head: True
+  class_num: 1000
+  drop_path: 0.1
+  use_masked_im_modeling: True
+  # `null` maps to Python None; the bare word "None" would be read as a string
+  norm_in_head: null
+  act_in_head: gelu
+  shared_head_teacher: True
+  global_crops_number: 2
+  global_crops_scale: [ 0.25, 1.0 ]
+  local_crops_number: 10
+  local_crops_scale: [ 0.05, 0.25 ]
+  pred_ratio: [ 0, 0.3 ]
+  pred_ratio_var: [ 0, 0.2 ]
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - IBOTLoss:
+        weight: 1.0
+        out_dim: 8192
+        patch_out_dim: 8192
+        # global_crops_number
+        ngcrops: 2
+        # local_crops_number
+        nlcrops: 10
+        warmup_teacher_temp: 0.04
+        teacher_temp: 0.07
+        # warmup_teacher_patch_temp
+        warmup_teacher_temp2: 0.04
+        # teacher_patch_temp
+        teacher_temp2: 0.07
+        warmup_teacher_temp_epochs: 30
+        nepochs: 800
+        lambda1: 1.0
+        lambda2: 1.0
+        mim_start_epoch: 0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: AdamWIBOT
+  weight_decay: 0.04
+  no_weight_decay_name: norm bias
+  clip_norm: 0
+  lr:
+    # for 8 cards
+    name: CosineIBOT
+    # base_value is scaled from Global.lr by global batch size / 256;
+    # update it whenever the batch size changes
+    base_value: 0.00125
+    final_value: 1e-06
+    epochs: 800
+    step_each_epoch: 1
+    warmup_epochs: 10
+    start_warmup_value: 0
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: IBOTDataset
+      image_root: /data3/linkaihao/dataset/mini-imagenet-1k
+      cls_label_path: /data3/linkaihao/dataset/mini-imagenet-1k/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_np: False
+            to_rgb: True
+            channel_first: False
+            backend: pil
+        - IBOTAugmentation:
+            global_crops_scale: [ 0.25, 1.0 ]
+            local_crops_scale: [ 0.05, 0.25 ]
+            global_crops_number: 2
+            local_crops_number: 10
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 32
+      drop_last: True
+      shuffle: True
+    loader:
+      num_workers: 10
+      use_shared_memory: True
+
+#Infer:
+#  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+#  batch_size: 10
+#  transforms:
+#    - DecodeImage:
+#        to_rgb: True
+#        channel_first: False
+#    - ResizeImage:
+#        resize_short: 256
+#    - CropImage:
+#        size: 224
+#    - NormalizeImage:
+#        scale: 1.0/255.0
+#        mean: [0.485, 0.456, 0.406]
+#        std: [0.229, 0.224, 0.225]
+#        order: ''
+#    - ToCHWImage:
+#  PostProcess:
+#    name: Topk
+#    topk: 5
+#    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+#
+#Metric:
+#  Train:
+#    - DistillationTopkAcc:
+#        model_key: "Student"
+#        topk: [1, 5]
+#  Eval:
+#    - DistillationTopkAcc:
+#        model_key: "Student"
+#        topk: [1, 5]
\ No newline at end of file
diff --git a/ppcls/data/__init__.py
b/ppcls/data/__init__.py
index df35eef640..a90be47f09 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -37,6 +37,7 @@ from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
 from ppcls.data.dataloader.cifar import Cifar10, Cifar100
 from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
+from ppcls.data.dataloader.ibot_dataset import IBOTDataset
 
 # sampler
 from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py
index 391dcef65b..09999dfdc8 100644
--- a/ppcls/data/dataloader/__init__.py
+++ b/ppcls/data/dataloader/__init__.py
@@ -14,3 +14,4 @@ from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
 from ppcls.data.dataloader.cifar import Cifar10, Cifar100
 from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
+from ppcls.data.dataloader.ibot_dataset import IBOTDataset
diff --git a/ppcls/data/dataloader/common_dataset.py b/ppcls/data/dataloader/common_dataset.py
index 7530137eb1..c77a26431e 100644
--- a/ppcls/data/dataloader/common_dataset.py
+++ b/ppcls/data/dataloader/common_dataset.py
@@ -71,6 +71,10 @@ def __getitem__(self, idx):
                 img = f.read()
             if self._transform_ops:
                 img = transform(img, self._transform_ops)
+
+            if isinstance(img, list):
+                return (img, self.labels[idx])
+
             img = img.transpose((2, 0, 1))
             return (img, self.labels[idx])
diff --git a/ppcls/data/dataloader/ibot_dataset.py b/ppcls/data/dataloader/ibot_dataset.py
new file mode 100644
index 0000000000..8cba4e8333
--- /dev/null
+++ b/ppcls/data/dataloader/ibot_dataset.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import os
+import random
+import math
+from .common_dataset import CommonDataset
+
+
+class IBOTDataset(CommonDataset):
+    """IBOTDataset
+
+    Args:
+        image_root (str): image root, path to `ILSVRC2012`
+        cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
+        transform_ops (list, optional): list of transform op(s). Defaults to None.
+        delimiter (str, optional): delimiter. Defaults to None.
+        relabel (bool, optional): whether to relabel when the original labels do not start from 0 or are discontinuous. Defaults to False.
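+        patch_size (int, optional): ViT patch size used to build the mask grid. Defaults to 16.
+        pred_ratio (float|list, optional): ratio of patch tokens to mask for masked image modeling. Defaults to 0.3.
+        pred_ratio_var (float|list, optional): allowed variance around pred_ratio. Defaults to 0.
+        pred_aspect_ratio (tuple, optional): aspect-ratio range of masked blocks. Defaults to (0.3, 1/0.3).
+        pred_shape (str, optional): masking shape, 'block' (BEiT-style) or 'rand'. Defaults to 'block'.
+        pred_start_epoch (int, optional): epoch at which masking starts. Defaults to 0.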
+ """ + + def __init__(self, + image_root, + cls_label_path, + transform_ops=None, + delimiter=None, + relabel=False, + patch_size=16, + pred_ratio=0.3, + pred_ratio_var=0, + pred_aspect_ratio=(0.3, 1/0.3), + pred_shape='block', + pred_start_epoch=0 + ): + self.delimiter = delimiter if delimiter is not None else " " + self.relabel = relabel + super(IBOTDataset, self).__init__(image_root, cls_label_path, + transform_ops) + self.psz = patch_size + self.pred_ratio = pred_ratio[0] if isinstance(pred_ratio, list) and \ + len(pred_ratio) == 1 else pred_ratio + self.pred_ratio_var = pred_ratio_var[0] if isinstance(pred_ratio_var, list) and \ + len(pred_ratio_var) == 1 else pred_ratio_var + if isinstance(self.pred_ratio, list) and not isinstance(self.pred_ratio_var, list): + self.pred_ratio_var = [self.pred_ratio_var] * len(self.pred_ratio) + self.log_aspect_ratio = tuple(map(lambda x: math.log(x), pred_aspect_ratio)) + self.pred_shape = pred_shape + self.pred_start_epoch = pred_start_epoch + + def _load_anno(self, seed=None): + assert os.path.exists( + self._cls_path), f"path {self._cls_path} does not exist." + assert os.path.exists( + self._img_root), f"path {self._img_root} does not exist." + self.images = [] + self.labels = [] + + with open(self._cls_path) as fd: + lines = fd.readlines() + if self.relabel: + label_set = set() + for line in lines: + line = line.strip().split(self.delimiter) + label_set.add(np.int64(line[1])) + label_map = { + oldlabel: newlabel + for newlabel, oldlabel in enumerate(label_set) + } + + if seed is not None: + np.random.RandomState(seed).shuffle(lines) + for line in lines: + line = line.strip().split(self.delimiter) + self.images.append(os.path.join(self._img_root, line[0])) + if self.relabel: + self.labels.append(label_map[np.int64(line[1])]) + else: + self.labels.append(np.int64(line[1])) + assert os.path.exists(self.images[ + -1]), f"path {self.images[-1]} does not exist." 
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    def get_pred_ratio(self):
+        if hasattr(self, 'epoch') and self.epoch < self.pred_start_epoch:
+            return 0
+
+        if isinstance(self.pred_ratio, list):
+            pred_ratio = [
+                random.uniform(prm - prv, prm + prv) if prv > 0 and prm >= prv else prm
+                for prm, prv in zip(self.pred_ratio, self.pred_ratio_var)
+            ]
+            pred_ratio = random.choice(pred_ratio)
+        else:
+            assert self.pred_ratio >= self.pred_ratio_var
+            pred_ratio = random.uniform(self.pred_ratio - self.pred_ratio_var, self.pred_ratio + \
+                self.pred_ratio_var) if self.pred_ratio_var > 0 else self.pred_ratio
+
+        return pred_ratio
+
+    def __getitem__(self, idx):
+        output = super(IBOTDataset, self).__getitem__(idx)
+
+        masks = []
+        for img in output[0]:
+            try:
+                H, W = img.shape[1] // self.psz, img.shape[2] // self.psz
+            except Exception:
+                # skip entries that carry no spatial shape (non-image)
+                continue
+
+            high = self.get_pred_ratio() * H * W
+            if self.pred_shape == 'block':
+                # following BEiT (https://arxiv.org/abs/2106.08254), see at
+                # https://github.com/microsoft/unilm/blob/b94ec76c36f02fb2b0bf0dcb0b8554a2185173cd/beit/masking_generator.py#L55
+                mask = np.zeros((H, W), dtype=bool)
+                mask_count = 0
+                while mask_count < high:
+                    max_mask_patches = high - mask_count
+
+                    delta = 0
+                    for attempt in range(10):
+                        low = (min(H, W) // 3) ** 2
+                        target_area = random.uniform(low, max_mask_patches)
+                        aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+                        h = int(round(math.sqrt(target_area * aspect_ratio)))
+                        w = int(round(math.sqrt(target_area / aspect_ratio)))
+                        if w < W and h < H:
+                            top = random.randint(0, H - h)
+                            left = random.randint(0, W - w)
+
+                            num_masked = mask[top: top + h, left: left + w].sum()
+                            if 0 < h * w - num_masked <= max_mask_patches:
+                                for i in range(top, top + h):
+                                    for j in range(left, left + w):
+                                        if mask[i, j] == 0:
+                                            mask[i, j] = 1
+                                            delta += 1
+
+                        if delta > 0:
+                            break
+
+                    if delta == 0:
+                        break
+                    else:
+                        mask_count += delta
+
+            elif self.pred_shape == 'rand':
+                mask = np.hstack([
+                    np.zeros(H * W - int(high)),
+                    np.ones(int(high)),
+                ]).astype(bool)
+                np.random.shuffle(mask)
+                mask = mask.reshape(H, W)
+            else:
+                raise ValueError(
+                    "Invalid pred_shape: {}".format(self.pred_shape))
+
+            masks.append(mask)
+
+        return output + (masks,)
\ No newline at end of file
diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py
index 66234a44bd..7f27a77352 100644
--- a/ppcls/data/preprocess/__init__.py
+++ b/ppcls/data/preprocess/__init__.py
@@ -50,6 +50,7 @@ from ppcls.data.preprocess.ops.operators import PCALighting
 from .ops.operators import format_data
 from paddle.vision.transforms import Pad as Pad_paddle_vision
+from ppcls.data.preprocess.ops.ibot_augment import IBOTAugmentation
 
 from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
 from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
diff --git a/ppcls/data/preprocess/ops/ibot_augment.py b/ppcls/data/preprocess/ops/ibot_augment.py
new file mode 100644
index 0000000000..225dda2581
--- /dev/null
+++ b/ppcls/data/preprocess/ops/ibot_augment.py
@@ -0,0 +1,143 @@
+import random
+
+import paddle
+from paddle.vision import BaseTransform, transforms
+from PIL import ImageFilter, ImageOps
+
+
+class GaussianBlur(object):
+    """
+    Apply Gaussian Blur to the PIL image.
+ """ + + def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = random.random() <= self.prob + if not do_it: + return img + + return img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + +class Solarization(object): + """ + Apply Solarization to the PIL image. + """ + + def __init__(self, p): + self.p = p + + def __call__(self, img): + if random.random() < self.p: + return ImageOps.solarize(img) + else: + return img + + +class RandomApply(BaseTransform): + def __init__(self, transforms: list, p=0.8): + super(RandomApply, self).__init__() + self.p = p + self.transforms = transforms + + def _apply_image(self, img): + if self.p < paddle.rand([1]): + return img + for t in self.transforms: + img = t(img) + return img + + +class RandomGrayscale(BaseTransform): + def __init__(self, prob=0.2): + super().__init__() + self.prob = prob + + def _apply_image(self, img): + if paddle.rand([1]) < self.prob: + nc = len(img.split()) + return transforms.to_grayscale(img, num_output_channels=nc) + return img + + +class IBOTAugmentation(object): + def __init__(self, + global_crops_scale, + local_crops_scale, + global_crops_number, + local_crops_number): + flip_and_color_jitter = transforms.Compose( + [ + transforms.RandomHorizontalFlip(prob=0.5), + RandomApply( + [ + transforms.ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1 + ) + ], + p=0.8, + ), + RandomGrayscale(prob=0.2), + ] + ) + normalize = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + ) + + self.global_crops_number = global_crops_number + # first global crop + self.global_transfo1 = transforms.Compose( + [ + transforms.RandomResizedCrop( + 224, scale=global_crops_scale, interpolation="bicubic" + ), + flip_and_color_jitter, + GaussianBlur(1.0), + normalize, + ] + ) + + # second global crop + self.global_transfo2 = transforms.Compose( + [ + transforms.RandomResizedCrop( + 224, scale=global_crops_scale, interpolation="bicubic" + ), + flip_and_color_jitter, + GaussianBlur(0.1), + Solarization(0.2), + normalize, + ] + ) + # transformation for the local small crops + self.local_crops_number = local_crops_number + self.local_transfo = transforms.Compose( + [ + transforms.RandomResizedCrop( + 96, scale=local_crops_scale, interpolation="bicubic" + ), + flip_and_color_jitter, + GaussianBlur(p=0.5), + normalize, + ] + ) + + + def __call__(self, image): + crops = [] + crops.append(self.global_transfo1(image)) + for _ in range(self.global_crops_number - 1): + crops.append(self.global_transfo2(image)) + for _ in range(self.local_crops_number): + crops.append(self.local_transfo(image)) + return crops diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py index 50bf9037f4..bb2bdeb19e 100644 --- a/ppcls/engine/train/__init__.py +++ b/ppcls/engine/train/__init__.py @@ -16,3 +16,4 @@ from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl from ppcls.engine.train.train_progressive import train_epoch_progressive from ppcls.engine.train.train_metabin import train_epoch_metabin +from ppcls.engine.train.train_ibot import train_epoch_ibot diff --git a/ppcls/engine/train/train_ibot.py b/ppcls/engine/train/train_ibot.py new file mode 100644 index 0000000000..ab669afecb --- /dev/null +++ b/ppcls/engine/train/train_ibot.py @@ -0,0 +1,135 @@ +import paddle +import 
math
+import sys
+from ppcls.engine.train.utils import update_loss, log_info
+import time
+import numpy as np
+import paddle.distributed as dist
+
+
+def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
+    if epoch >= freeze_last_layer:
+        return
+    for n, p in model.named_parameters():
+        if "last_layer" in n:
+            # `stop_gradient` cannot be toggled here; drop the gradient
+            # accumulated in this step instead
+            p.clear_grad()
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
+    warmup_schedule = np.array([])
+    warmup_iters = warmup_epochs * niter_per_ep
+    if warmup_epochs > 0:
+        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+    iters = np.arange(epochs * niter_per_ep - warmup_iters)
+    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+
+    schedule = np.concatenate((warmup_schedule, schedule))
+    assert len(schedule) == epochs * niter_per_ep
+    return schedule
+
+
+def train_epoch_ibot(engine, epoch_id, print_batch_step):
+    tic = time.time()
+    lr_schedule = cosine_scheduler(
+        engine.config["Global"]["lr"] * engine.config["Global"]["batch_size"] * dist.get_world_size() / 256,
+        engine.config["Global"]["min_lr"],
+        engine.config["Global"]["epochs"], len(engine.train_dataloader),
+        warmup_epochs=engine.config["Global"]["warmup_epochs"])
+    wd_schedule = cosine_scheduler(
+        engine.config["Global"]["weight_decay"],
+        engine.config["Global"]["weight_decay_end"],
+        engine.config["Global"]["epochs"], len(engine.train_dataloader))
+    momentum_schedule = cosine_scheduler(
+        engine.config["Global"]["momentum_teacher"], 1,
+        engine.config["Global"]["epochs"], len(engine.train_dataloader))
+
+    for iter_id, (images, labels, masks) in enumerate(engine.train_dataloader):
+        cur_iter_num = len(engine.train_dataloader) * epoch_id + iter_id
+
+        if iter_id == 5:
+            for key in engine.time_info:
+                engine.time_info[key].reset()
+        engine.time_info["reader_cost"].update(time.time() - tic)
+
+        # for i, param_group in enumerate(engine.optimizer[0]._param_groups):
+        #     engine.optimizer[0].set_lr(lr_schedule[cur_iter_num])  # raises an error
+        #     if i == 0:
+        #         # only the first group is regularized
+        #         param_group["weight_decay"] = wd_schedule[cur_iter_num]  # raises an error
+
+        # NOTE: IBOT.forward returns (teacher_output, student_output, student_local_cls)
+        if engine.amp:
+            amp_level = engine.config['AMP'].get("level", "O1").upper()
+            with paddle.amp.auto_cast(level=amp_level):
+                teacher_out, student_out, student_local_cls = engine.model.forward(images, masks)
+                loss = engine.train_loss_func.loss_func[0](student_out, teacher_out, student_local_cls, masks, epoch_id)
+        else:
+            teacher_out, student_out, student_local_cls = engine.model.forward(images, masks)
+            loss = engine.train_loss_func.loss_func[0](student_out, teacher_out, student_local_cls, masks, epoch_id)
+
+        loss = loss["loss"]
+        if not math.isfinite(loss.item()):
+            print("Loss is {}, stopping training".format(loss.item()))
+            sys.exit(1)
+
+        # log statistics
+        # probs1 = teacher_output[0].chunk(args.global_crops_number)
+        # probs2 = student_output[0].chunk(args.global_crops_number)
+        #
+        # if dist.is_initialized():
+        #     pred1 = utils.concat_all_gather(paddle.argmax(probs1[0], axis=1))
+        #     pred2 = utils.concat_all_gather(paddle.argmax(probs2[1], axis=1))
+        # else:
+        #     pred1 = paddle.argmax(probs1[0], axis=1)
+        #     pred2 = paddle.argmax(probs2[1], axis=1)
+        #
+        # acc = ((pred1 == pred2).sum()) / pred1.shape[0]
+        # pred_labels.append(pred1)
+        # if dist.is_initialized():
+        #     real_labels.append(utils.concat_all_gather(labels.cuda()))
+        # else:
+        #     real_labels.append(labels.cuda())
+
+        # clear grad
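+        # engine.optimizer is a list of optimizers, so clear each one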
+        for i in range(len(engine.optimizer)):
+            engine.optimizer[i].clear_grad()
+
+        # student update
+        if engine.amp:
+            engine.scaler.scale(loss).backward()
+            if engine.optimizer[0]._grad_clip is not None:
+                engine.scaler.unscale_(engine.optimizer[0])
+            cancel_gradients_last_layer(epoch_id, engine.model.student,
+                                        engine.config["Global"]["freeze_last_layer"])
+            for i in range(len(engine.optimizer)):
+                engine.scaler.step(engine.optimizer[i])
+            engine.scaler.update()
+        else:
+            loss.backward()
+            cancel_gradients_last_layer(epoch_id, engine.model.student,
+                                        engine.config["Global"]["freeze_last_layer"])
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].step()
+
+        # step lr(by step)
+        for i in range(len(engine.lr_sch)):
+            if not getattr(engine.lr_sch[i], "by_epoch", False):
+                engine.lr_sch[i].step()
+
+        batch_size = engine.train_dataloader.batch_size
+        update_loss(engine, loss, batch_size)
+        engine.time_info["batch_cost"].update(time.time() - tic)
+
+        if iter_id % print_batch_step == 0:
+            log_info(engine, batch_size, epoch_id, iter_id)
+
+        # EMA update for the teacher; index the schedule by the global
+        # iteration so the momentum keeps increasing across epochs
+        with paddle.no_grad():
+            m = momentum_schedule[cur_iter_num]
+            for param_stu, params_tea in zip(engine.model.student.parameters(),
+                                             engine.model.teacher_without_ddp.parameters()):
+                new_val = m * params_tea.numpy() + (1 - m) * param_stu.detach().numpy()
+                params_tea.set_value(new_val)
+
+        tic = time.time()
+
+    paddle.device.cuda.synchronize()
+    engine.output_info['train_loss'] = loss.item()
+    engine.output_info['train_lr'] = engine.optimizer[0].get_lr()
+    engine.output_info['train_wd'] = engine.optimizer[0]._param_groups[0]["weight_decay"]
\ No newline at end of file
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index adf770dfd2..8c62b304c8 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -47,6 +47,8 @@ from .metabinloss import InterDomainShuffleLoss
 from .metabinloss import IntraDomainScatterLoss
 
+from .ibotloss import IBOTLoss
+
 
 class CombinedLoss(nn.Layer):
     def __init__(self, config_list):
diff --git a/ppcls/loss/ibotloss.py b/ppcls/loss/ibotloss.py
new file mode 100644
index 0000000000..6f40c7cede
--- /dev/null
+++ b/ppcls/loss/ibotloss.py
@@ -0,0 +1,103 @@
+import numpy as np
+from paddle import nn
+import paddle
+import paddle.nn.functional as F
+import paddle.distributed as dist
+
+
+class IBOTLoss(nn.Layer):
+    def __init__(self, out_dim=8192, patch_out_dim=8192, ngcrops=2, nlcrops=0, warmup_teacher_temp=0.04,
+                 teacher_temp=0.04, warmup_teacher_temp2=0.04, teacher_temp2=0.07,
+                 warmup_teacher_temp_epochs=30, nepochs=200, student_temp=0.1,
+                 center_momentum=0.9, center_momentum2=0.9,
+                 lambda1=1.0, lambda2=1.0, mim_start_epoch=0):
+        super().__init__()
+        self.student_temp = student_temp
+        self.center_momentum = center_momentum
+        self.center_momentum2 = center_momentum2
+        self.ngcrops = ngcrops
+        self.nlcrops = nlcrops
+        self.ncrops = ngcrops + nlcrops
+        self.register_buffer("center", paddle.zeros((1, out_dim)))
+        self.register_buffer("center2", paddle.zeros((1, 1, patch_out_dim)))
+        self.lambda1 = lambda1
+        self.lambda2 = lambda2
+
+        # we apply a warm-up to the teacher temperature because
+        # too high a temperature makes the training unstable at the beginning
+        self.teacher_temp_schedule = np.concatenate((
+            np.linspace(warmup_teacher_temp,
+                        teacher_temp, warmup_teacher_temp_epochs),
+            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
+        ))
+        self.teacher_temp2_schedule = np.concatenate((
+            np.linspace(warmup_teacher_temp2,
+                        teacher_temp2, warmup_teacher_temp_epochs),
+            np.ones(nepochs -
warmup_teacher_temp_epochs) * teacher_temp2 + )) if mim_start_epoch == 0 else np.concatenate(( + np.ones(mim_start_epoch) * warmup_teacher_temp2, + np.linspace(warmup_teacher_temp2, + teacher_temp2, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs - mim_start_epoch) * teacher_temp2 + )) + + def forward(self, student_output, teacher_output, student_local_cls, student_mask, epoch): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + student_cls, student_patch = student_output + teacher_cls, teacher_patch = teacher_output + + if student_local_cls is not None: + student_cls = paddle.concat([student_cls, student_local_cls]) + + # [CLS] and patch for global patches + student_cls = student_cls / self.student_temp + student_cls_c = student_cls.chunk(self.ncrops) + student_patch = student_patch / self.student_temp + student_patch_c = student_patch.chunk(self.ngcrops) + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + temp2 = self.teacher_temp2_schedule[epoch] + teacher_cls_c = F.softmax((teacher_cls - self.center) / temp, axis=-1) + teacher_cls_c = teacher_cls_c.detach().chunk(self.ngcrops) + teacher_patch_c = F.softmax((teacher_patch - self.center2) / temp2, axis=-1) + teacher_patch_c = teacher_patch_c.detach().chunk(self.ngcrops) + + total_loss1, n_loss_terms1 = 0, 0 + total_loss2, n_loss_terms2 = 0, 0 + for q in range(len(teacher_cls_c)): + for v in range(len(student_cls_c)): + if v == q: + loss2 = paddle.sum(-teacher_patch_c[q] * F.log_softmax(student_patch_c[v], axis=-1), axis=-1) + mask = paddle.flatten(student_mask[v].astype('float64'),-2,-1) + loss2 = paddle.sum(loss2 * mask, axis=-1) / mask.sum(axis=-1).clip(min=1.0) + total_loss2 += loss2.mean() + n_loss_terms2 += 1 + else: + loss1 = paddle.sum(-teacher_cls_c[q] * F.log_softmax(student_cls_c[v], axis=-1), axis=-1) + total_loss1 += loss1.mean() + n_loss_terms1 += 1 + + total_loss1 = total_loss1 / n_loss_terms1 * self.lambda1 + total_loss2 = total_loss2 / n_loss_terms2 * self.lambda2 + total_loss = dict(cls=total_loss1, patch=total_loss2, loss=total_loss1 + total_loss2) + self.update_center(teacher_cls, teacher_patch) + return total_loss + + @paddle.no_grad() + def update_center(self, teacher_cls, teacher_patch): + """ + Update center used for teacher output. 
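+        Both centers are exponential moving averages of the batch-mean
+        teacher outputs, all-reduced across workers when running distributed.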
+ """ + cls_center = paddle.sum(teacher_cls, axis=0, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(cls_center) + cls_center = cls_center / (len(teacher_cls) * dist.get_world_size()) + self.center = self.center * self.center_momentum + cls_center * (1 - self.center_momentum) + + patch_center = paddle.sum(teacher_patch.mean(1), axis=0, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(patch_center) + patch_center = patch_center / (len(teacher_patch) * dist.get_world_size()) + self.center2 = self.center2 * self.center_momentum2 + patch_center * (1 - self.center_momentum2) \ No newline at end of file diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py index f1d9074e45..15ac6a2249 100644 --- a/ppcls/optimizer/learning_rate.py +++ b/ppcls/optimizer/learning_rate.py @@ -20,7 +20,7 @@ from typing import Union from paddle.optimizer import lr from ppcls.utils import logger - +import numpy as np class LRBase(object): """Base class for custom learning rates @@ -600,3 +600,36 @@ def _lr_lambda(current_step): last_epoch=self.last_epoch) setattr(learning_rate, "by_epoch", self.by_epoch) return learning_rate + +class CosineIBOT(LRBase): + def __init__(self, + base_value, + final_value, + epochs, + step_each_epoch, + warmup_epochs=0, + start_warmup_value=0): + self.base_value = base_value + + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * step_each_epoch + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * step_each_epoch - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + + self.schedule = np.concatenate((warmup_schedule, schedule)) + assert len(self.schedule) == epochs * step_each_epoch + + def __call__(self): + def _lr_lambda(current_step): + return self.schedule[current_step] + + learning_rate = lr.LambdaDecay( + learning_rate=self.base_value, + lr_lambda=_lr_lambda, + last_epoch=-1 + ) + setattr(learning_rate, "by_epoch", False) + return learning_rate \ No newline at end of file diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py index f3c3d354b8..0e832bdf39 100644 --- a/ppcls/optimizer/optimizer.py +++ b/ppcls/optimizer/optimizer.py @@ -21,7 +21,7 @@ from paddle import optimizer as optim from ppcls.utils import logger from functools import partial - +import paddle.regularizer class SGD(object): """ @@ -429,3 +429,84 @@ def __call__(self, model_list): optimizer = self.AdamWDLImpl(**opt_args) return optimizer + +class AdamWIBOT(optim.AdamW): + """ + This optimizer can dynamically adjust the coefficient of weight_decay. 
+ """ + def __init__(self, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + weight_decay=None, + multi_precision=False, + grad_clip=None, + no_weight_decay_name=None, + one_dim_param_no_weight_decay=False, + **args): + + self.learning_rate = learning_rate + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.grad_clip = grad_clip + if isinstance(weight_decay, str): + weight_decay = [float(it) for it in weight_decay.split()] + from .learning_rate import Cosine + self.weight_decay = Cosine( + self.learning_rate.epochs, self.learning_rate.step_each_epoch, + self.learning_rate.learning_rate, self.learning_rate.eta_min + ) + else: + self.weight_decay = weight_decay + + self.weight_decay = weight_decay + self.multi_precision = multi_precision + self.no_weight_decay_name_list = no_weight_decay_name.split( + ) if no_weight_decay_name else [] + self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay + + def __call__(self, model_list): + # model_list is None in static graph + parameters = sum([m.parameters() for m in model_list], + []) if model_list else None + + if model_list is None: + if self.one_dim_param_no_weight_decay or len( + self.no_weight_decay_name_list) != 0: + msg = "\"AdamW\" does not support setting \"no_weight_decay\" in static graph. Please use dynamic graph." + logger.error(Exception(msg)) + raise Exception(msg) + + self.no_weight_decay_param_name_list = [ + p.name for model in model_list for n, p in model.named_parameters() + if any(nd in n for nd in self.no_weight_decay_name_list) + ] if model_list else [] + + if self.one_dim_param_no_weight_decay: + self.no_weight_decay_param_name_list += [ + p.name + for model in model_list for n, p in model.named_parameters() + if len(p.shape) == 1 + ] if model_list else [] + + super().__init__( + learning_rate=self.learning_rate, + beta1=self.beta1, + beta2=self.beta2, + epsilon=self.epsilon, + parameters=parameters, + weight_decay=self.weight_decay if isinstance(self.weight_decay, float) else self.weight_decay.get_lr(), + multi_precision=self.multi_precision, + grad_clip=self.grad_clip, + apply_decay_param_fun=self._apply_decay_param_fun) + + return self + + def _apply_decay_param_fun(self, name): + return name not in self.no_weight_decay_param_name_list + + def wd_step(self): + self.weight_decay.step() + self.regularization = paddle.regularizer.L2Decay(self.weight_decay.get_lr()) \ No newline at end of file