diff --git a/tests/test_models.py b/tests/test_models.py
index b7bd143105..e275e43028 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -17,7 +17,7 @@
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'perceiver*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
@@ -26,7 +26,7 @@
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'perceiver_l*']
 else:
     EXCLUDE_FILTERS = []
 
@@ -218,11 +218,12 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
 
     # check first conv(s) names match default_cfg
     first_conv = cfg['first_conv']
-    if isinstance(first_conv, str):
-        first_conv = (first_conv,)
-    assert isinstance(first_conv, (tuple, list))
-    for fc in first_conv:
-        assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
+    if first_conv is not None:
+        if isinstance(first_conv, str):
+            first_conv = (first_conv,)
+        assert isinstance(first_conv, (tuple, list))
+        for fc in first_conv:
+            assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
 
 
 if 'GITHUB_ACTIONS' not in os.environ:
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 56c812d4b8..884435db16 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -22,6 +22,7 @@
 from .nasnet import *
 from .nest import *
 from .nfnet import *
+from .perceiver import *
 from .pit import *
 from .pnasnet import *
 from .regnet import *
@@ -36,6 +37,7 @@
 from .swin_transformer import *
 from .tnt import *
 from .tresnet import *
+from .twins import *
 from .vgg import *
 from .visformer import *
 from .vision_transformer import *
@@ -44,7 +46,6 @@
 from .xception import *
 from .xception_aligned import *
 from .xcit import *
-from .twins import *
 
 from .factory import create_model, split_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 4e0f2b2111..8712a62ce0 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -8,12 +8,6 @@
 
 Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
 
-Status:
-* These models are a work in progress, experiments ongoing.
-* Pretrained weights for two models so far, more to come.
-* Model details updated to closer match official JAX code now that it's released
-* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
-
 Hacked together by / copyright Ross Wightman, 2021.
""" import math diff --git a/timm/models/perceiver.py b/timm/models/perceiver.py new file mode 100644 index 0000000000..76a153e5b8 --- /dev/null +++ b/timm/models/perceiver.py @@ -0,0 +1,493 @@ +""" Perceiver + +Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206 + +Official Deepmind code: TBD (doesn't exist yet) + +Fourier feature position embedding references: + * Official NeRF impl - https://github.com/bmild/nerf + * Lucidrain's Perceiver impl - https://github.com/lucidrains/perceiver-pytorch + +Status: +* Work in progress, currently running training trials with S and M models (rather slow) + +Hacked together by / copyright Ross Wightman, 2021. +""" +import math +from functools import partial +from typing import List, Tuple + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, named_apply +from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': None, 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'perceiver_ss': _cfg( + url='', input_size=(3, 192, 192)), + 'perceiver_s': _cfg( + url='', input_size=(3, 192, 192)), + 'perceiver_m': _cfg( + url=''), + 'perceiver_m_ls': _cfg( + url=''), + 'perceiver_l': _cfg( + url=''), +} + + +def fourier_encode(x, max_freq_log2: int = 8, num_bands: int = 64): + """ Fourier feature embedding. + Referenced official NeRF code and Lucidrain's PyTorch Perceiver impl. + """ + # FIXME this will likely need to change once official code / weights are available + x = x.unsqueeze(-1) + bands = 2 ** torch.linspace(0, max_freq_log2 - 1, num_bands, device=x.device, dtype=x.dtype) + x_bands = x * math.pi * bands + x = torch.cat([x, x_bands.sin(), x_bands.cos()], dim=-1) + return x + + +def fourier_grid( + shape: List[int], max_freq_log2: int = 8, num_bands: int = 64, device: torch.device = torch.device('cuda')): + grid = torch.stack(torch.meshgrid([torch.linspace(-1., 1., steps=s, device=device) for s in shape]), dim=-1) + enc_pos = fourier_encode(grid, max_freq_log2, num_bands) + return enc_pos.transpose(-1, -2).flatten(len(shape)) + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + """ + """ + def __init__(self, latent_dim, data_dim, attn_dim=None, num_heads=1, qkv_bias=True, proj_drop=0.): + super().__init__() + assert latent_dim % num_heads == 0, f"dim {latent_dim} should be divided by num_heads {num_heads}." 
+
+        self.latent_dim = latent_dim
+        self.attn_dim = attn_dim or min(latent_dim, data_dim)
+        self.num_heads = num_heads
+        head_dim = self.attn_dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.q = nn.Linear(latent_dim, self.attn_dim, bias=qkv_bias)
+        self.kv = nn.Linear(data_dim, self.attn_dim * 2, bias=qkv_bias)
+        self.proj = nn.Linear(self.attn_dim, latent_dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, latent, data):
+        B = latent.shape[0]
+        q = self.q(latent).reshape(B, -1, self.num_heads, self.attn_dim // self.num_heads).permute(0, 2, 1, 3)
+
+        kv = self.kv(data).reshape(B, -1, 2, self.num_heads, self.attn_dim // self.num_heads).permute(2, 0, 3, 1, 4)
+        k, v = kv[0], kv[1]
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v).transpose(1, 2).reshape(B, -1, self.attn_dim)
+        out = self.proj(out)
+        out = self.proj_drop(out)
+        return out
+
+
+class Affine(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
+        self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
+
+    def forward(self, x):
+        return torch.addcmul(self.beta, self.alpha, x)
+
+
+@torch.jit.interface
+class CrossInterface(torch.nn.Module):
+    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class CrossBlock(nn.Module):
+
+    def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True,
+                 drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1_latent = norm_layer(latent_dim)
+        self.norm1_data = norm_layer(data_dim)
+        self.attn = attn_layer(
+            latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(latent_dim)
+        mlp_hidden_dim = int(latent_dim * mlp_ratio)
+        self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
+        latent = latent + self.drop_path(self.attn(
+            self.norm1_latent(latent),
+            self.norm1_data(data),
+        ))
+        latent = latent + self.drop_path(self.mlp(self.norm2(latent)))
+        return latent
+
+
+class CrossBlockLayerScale(nn.Module):
+
+    def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
+                 drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1_latent = norm_layer(latent_dim)
+        self.norm1_data = norm_layer(data_dim)
+        self.attn = attn_layer(
+            latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(latent_dim)
+        mlp_hidden_dim = int(latent_dim * mlp_ratio)
+        self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
+        self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
+
+    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
+        latent = latent + self.drop_path(self.ls1 * self.attn(
+            self.norm1_latent(latent),
+            self.norm1_data(data),
+        ))
+        latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
+        return latent
+
+
+@torch.jit.interface
+class TransformerInterface(torch.nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        pass
+
+
+class TransformerBlock(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
+                 drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class TransformerBlockLayerScale(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
+                 drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.ls1 = nn.Parameter(init_values * torch.ones(dim))
+        self.ls2 = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class TransformerStack(nn.Module):
+    """ A stack-o-transformers
+    NOTE: this could have been a plain nn.Sequential, but it's wrapped in a module so the
+    Interface def can be used for ModuleDict torchscript compat.
+ """ + def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs): + super().__init__() + block = block or TransformerBlock + self.stack = nn.Sequential(*[ + block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs) + for _ in range(depth) + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.stack(x) + + +def get_layer_layout(cross_depths, num_stages=8, share_weights=None): + if isinstance(cross_depths, (tuple, list)): + stage_cross_depths = tuple(cross_depths) + stage_cross_depths = (stage_cross_depths + (0,) * num_stages)[:num_stages] + else: + stage_cross_depths = to_ntuple(num_stages)(cross_depths) + prev_cross_key = '' + prev_transformer_key = '' + keys = [] + num_cross = 0 + num_transformer = 0 + for i, cd in enumerate(stage_cross_depths): + for j in range(cd): + key = prev_cross_key + if share_weights is None or num_cross <= share_weights[0]: + key = f'c{i}_{j}' + keys += [key] + prev_cross_key = key + num_cross += 1 + key = prev_transformer_key + if share_weights is None or num_transformer <= share_weights[1]: + key = f't{i}' + keys += [key] + prev_transformer_key = key + num_transformer += 1 + return keys + + +class Perceiver(nn.Module): + """ Perceiver + + Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206 + """ + + def __init__( + self, in_chans=3, num_classes=1000, num_stages=8, cross_depths=(1,), transformer_depth=6, + latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0, + cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0), + pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10, + data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None, + cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None, + drop_rate=0., drop_path_rate=0., weight_init=''): + """ + Args: + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + num_stages (int): number of stages (cross + transformer stack repeats) + num_cross_heads (int): number of cross-attention heads + cross_mlp_ratio (flaot): ratio of mlp hidden dim to embedding dim + share_weights (Optiona[Tuple]): starting index of latent and transformer share (or None for no share) + latent_dim (int): + num_latents (int): + num_latent_heads (int): number of latent-attention heads + latent_mlp_ratio (float): + qkv_bias (bool): enable bias for qkv if True + pos_embed_type (str): type of pos embed (TODO: currently defaults to fourier) + pos_embed_dim (int): embedding dimension (for other pos-embed options besides fourier) + data_bands (int): + data_ndim (int): + data_max_freq (int): + drop_rate (float): dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.latent_dim = latent_dim + cross_block = cross_block or CrossBlock + transformer_block = transformer_block or TransformerBlock + cross_attn_layer = cross_attn_layer or CrossAttention + attn_layer = attn_layer or Attention + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.latents = nn.Parameter(torch.zeros(num_latents, latent_dim)) + self.data_bands = data_bands + self.data_max_freq = data_max_freq + self.data_ndim = data_ndim + self.data_dim = self.data_ndim * (2 * self.data_bands + 1) + in_chans 
+        self.data_spatial = data_spatial
+
+        self.blocks_cross = nn.ModuleDict()
+        self.blocks_trans = nn.ModuleDict()
+        self.layer_keys = get_layer_layout(cross_depths, num_stages, share_weights)
+        for i, k in enumerate(self.layer_keys):
+            stage_args = dict(
+                qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
+            if k.startswith('c'):
+                self.blocks_cross[k] = cross_block(
+                    latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
+                    mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
+            else:
+                self.blocks_trans[k] = TransformerStack(
+                    depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
+                    mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
+
+        self.norm = norm_layer(latent_dim)
+        self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.init_weights(weight_init)
+
+    def init_weights(self, mode=''):
+        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+        trunc_normal_(self.latents, std=.02)
+        named_apply(partial(_init_weights, head_bias=head_bias), self)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'latents'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.latent_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        B, C, H, W = x.shape
+        # FIXME cache fourier embedding and implement positional options
+        # FIXME support ndim inputs, don't assume 2D?
+        data = fourier_grid(x.shape[2:], max_freq_log2=self.data_max_freq, num_bands=self.data_bands, device=x.device)
+        if self.data_spatial:
+            data = torch.cat([x, data.unsqueeze(0).expand(B, -1, -1, -1).permute(0, 3, 1, 2)], dim=1)
+        else:
+            data = torch.cat([x.permute(0, 2, 3, 1), data.unsqueeze(0).expand(B, -1, -1, -1)], dim=-1)
+        data = data.reshape(B, H * W, -1)
+        x = self.latents.unsqueeze(0).expand(B, -1, -1)
+        for k in self.layer_keys:
+            if k.startswith('c'):
+                cross_blocks: CrossInterface = self.blocks_cross[k]  # interface annotation for torchscript silliness
+                x = cross_blocks.forward(x, data)
+            else:
+                transformer: TransformerInterface = self.blocks_trans[k]
+                x = transformer.forward(x)
+        x = self.norm(x)
+        x = x.mean(dim=1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
+    """ weight initialization
+    """
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                if 'mlp' in name:
+                    nn.init.normal_(module.bias, std=1e-6)
+                else:
+                    nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+        nn.init.zeros_(module.bias)
+        nn.init.ones_(module.weight)
+
+
+def _create_perceiver(variant, pretrained=False, default_cfg=None, **kwargs):
+    default_cfg = default_cfg or default_cfgs[variant]
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for Perceiver models.')
+    model = build_model_with_cfg(
+        Perceiver, variant, pretrained,
+        default_cfg=default_cfg,
+        **kwargs)
+    return model
+
+
+@register_model
+def perceiver_ss(pretrained=False, **kwargs):
+    """ Perceiver-Small (Shared)
+    One initial cross attn and all transformer stacks shared. ~11M params
+    """
+    model_kwargs = dict(
+        cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36, **kwargs)
+    model = _create_perceiver('perceiver_ss', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def perceiver_s(pretrained=False, **kwargs):
+    """ Perceiver-Small
+    One initial cross attn, all transformer stacks after the first shared. ~20M params
+    """
+    model_kwargs = dict(
+        cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36,
+        share_weights=(1, 1), **kwargs)
+    model = _create_perceiver('perceiver_s', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def perceiver_m(pretrained=False, **kwargs):
+    """ Perceiver-Medium
+    Two cross attn (one before each of the first two transformer stacks), all transformer stacks shared. ~25M params.
+    """
+    model_kwargs = dict(
+        cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40, **kwargs)
+    model = _create_perceiver('perceiver_m', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def perceiver_m_ls(pretrained=False, **kwargs):
+    """ Perceiver-Medium w/ LayerScale + Affine
+    Two cross attn (one before each of the first two transformer stacks), all transformer stacks shared. ~25M params.
+    LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI
+    """
+    model_kwargs = dict(
+        cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
+        transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
+        norm_layer=Affine, **kwargs)
+    model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def perceiver_l(pretrained=False, **kwargs):
+    """ Perceiver-Large
+    One cross attn before each of the 8 transformer stacks. All but the first cross attn shared, all transformer stacks shared.
+    This variant is closest to the paper model for reported ImageNet results. ~45M params.
+    """
+    model_kwargs = dict(cross_depths=1, latent_dim=1024, num_latents=512, **kwargs)
+    model = _create_perceiver('perceiver_l', pretrained=pretrained, **model_kwargs)
+    return model
\ No newline at end of file
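
Reviewer note: a minimal smoke test of the new registrations, assuming this branch is installed as timm (the model name and the 192x192 input size come from the perceiver_s default_cfg above; weights are random since no pretrained URLs exist yet):

import torch
import timm

# instantiate one of the newly registered Perceiver variants (random init, no pretrained weights)
model = timm.create_model('perceiver_s', pretrained=False)
model.eval()

# forward a dummy batch at the config's input resolution
x = torch.randn(2, 3, 192, 192)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([2, 1000])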
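
The per-pixel data width seen by CrossAttention's kv projection follows from fourier_encode: each of the data_ndim coordinates contributes the raw coordinate plus data_bands sine and data_bands cosine features, and the in_chans image channels are concatenated on top, i.e. data_dim = data_ndim * (2 * data_bands + 1) + in_chans (149 with the perceiver_s settings). A quick sketch checking that against fourier_grid, again assuming this branch is importable:

import torch
from timm.models.perceiver import fourier_grid  # module added by this patch

H = W = 192
in_chans, data_ndim, data_bands = 3, 2, 36  # perceiver_s settings from this patch

# positional features for a 2D grid; max_freq_log2=10 matches the Perceiver default data_max_freq
pos = fourier_grid([H, W], max_freq_log2=10, num_bands=data_bands, device=torch.device('cpu'))
assert pos.shape == (H, W, data_ndim * (2 * data_bands + 1))
print(pos.shape[-1] + in_chans)  # 149, the data_dim used to size CrossAttention's kv projection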