Commit acfd85a

All swin models support spatial output, add output_fmt to v1/v2 and use ClassifierHead.
* update ClassifierHead to allow different input format
* add output format support to patch embed
* fix some flatten issues for a few conv head models
* add Format enum and helpers for tensor format (layout) choices
1 parent c30a160 commit acfd85a
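A quick sketch of what the new input_fmt plumbing enables (a usage example, not part of the commit; it assumes a timm build that includes these changes): the pooling and classifier head can consume channels-last (NHWC) feature maps directly, as produced by Swin-style stages.

```python
import torch
from timm.layers import ClassifierHead, SelectAdaptivePool2d

# NHWC feature map, e.g. from a Swin stage: (batch, H, W, channels)
feats = torch.randn(2, 7, 7, 768)

# global average pool that reduces dims (1, 2) instead of (2, 3)
pool = SelectAdaptivePool2d(pool_type='avg', flatten=True, input_fmt='NHWC')
print(pool(feats).shape)   # torch.Size([2, 768])

# full head: pool + dropout + linear classifier, fed NHWC features
head = ClassifierHead(in_features=768, num_classes=1000, pool_type='avg', input_fmt='NHWC')
print(head(feats).shape)   # torch.Size([2, 1000])
```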

20 files changed (+1419, -1000 lines)

tests/test_models.py

Lines changed: 162 additions & 151 deletions
Large diffs are not rendered by default.

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
     EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
 from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
+from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple

timm/layers/adaptive_avgmax_pool.py

Lines changed: 85 additions & 25 deletions
@@ -9,31 +9,37 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .format import get_spatial_dim, get_channel_dim
+
+_int_tuple_2_t = Union[int, Tuple[int, int]]
+
 
 def adaptive_pool_feat_mult(pool_type='avg'):
-    if pool_type == 'catavgmax':
+    if pool_type.endswith('catavgmax'):
         return 2
     else:
         return 1
 
 
-def adaptive_avgmax_pool2d(x, output_size=1):
+def adaptive_avgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
     x_avg = F.adaptive_avg_pool2d(x, output_size)
     x_max = F.adaptive_max_pool2d(x, output_size)
     return 0.5 * (x_avg + x_max)
 
 
-def adaptive_catavgmax_pool2d(x, output_size=1):
+def adaptive_catavgmax_pool2d(x, output_size: _int_tuple_2_t = 1):
     x_avg = F.adaptive_avg_pool2d(x, output_size)
     x_max = F.adaptive_max_pool2d(x, output_size)
     return torch.cat((x_avg, x_max), 1)
 
 
-def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
+def select_adaptive_pool2d(x, pool_type='avg', output_size: _int_tuple_2_t = 1):
     """Selectable global pooling function with dynamic input kernel size
     """
     if pool_type == 'avg':
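The endswith() change above matters for feature-width bookkeeping: 'fastcatavgmax' should report the same 2x multiplier as 'catavgmax'. A quick check (a sketch, assuming the module is importable as timm.layers.adaptive_avgmax_pool):

```python
from timm.layers.adaptive_avgmax_pool import adaptive_pool_feat_mult

# concatenated avg+max doubles the pooled feature dim, fast variant included
assert adaptive_pool_feat_mult('catavgmax') == 2
assert adaptive_pool_feat_mult('fastcatavgmax') == 2
assert adaptive_pool_feat_mult('avg') == 1
```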
@@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
     return x
 
 
-class FastAdaptiveAvgPool2d(nn.Module):
-    def __init__(self, flatten=False):
-        super(FastAdaptiveAvgPool2d, self).__init__()
+class FastAdaptiveAvgPool(nn.Module):
+    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
+        super(FastAdaptiveAvgPool, self).__init__()
+        self.flatten = flatten
+        self.dim = get_spatial_dim(input_fmt)
+
+    def forward(self, x):
+        return x.mean(self.dim, keepdim=not self.flatten)
+
+
+class FastAdaptiveMaxPool(nn.Module):
+    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
+        super(FastAdaptiveMaxPool, self).__init__()
         self.flatten = flatten
+        self.dim = get_spatial_dim(input_fmt)
+
+    def forward(self, x):
+        return x.amax(self.dim, keepdim=not self.flatten)
+
+
+class FastAdaptiveAvgMaxPool(nn.Module):
+    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
+        super(FastAdaptiveAvgMaxPool, self).__init__()
+        self.flatten = flatten
+        self.dim = get_spatial_dim(input_fmt)
+
+    def forward(self, x):
+        x_avg = x.mean(self.dim, keepdim=not self.flatten)
+        x_max = x.amax(self.dim, keepdim=not self.flatten)
+        return 0.5 * x_avg + 0.5 * x_max
+
+
+class FastAdaptiveCatAvgMaxPool(nn.Module):
+    def __init__(self, flatten: bool = False, input_fmt: str = 'NCHW'):
+        super(FastAdaptiveCatAvgMaxPool, self).__init__()
+        self.flatten = flatten
+        self.dim_reduce = get_spatial_dim(input_fmt)
+        if flatten:
+            self.dim_cat = 1
+        else:
+            self.dim_cat = get_channel_dim(input_fmt)
 
     def forward(self, x):
-        return x.mean((2, 3), keepdim=not self.flatten)
+        x_avg = x.mean(self.dim_reduce, keepdim=not self.flatten)
+        x_max = x.amax(self.dim_reduce, keepdim=not self.flatten)
+        return torch.cat((x_avg, x_max), self.dim_cat)
 
 
 class AdaptiveAvgMaxPool2d(nn.Module):
-    def __init__(self, output_size=1):
+    def __init__(self, output_size: _int_tuple_2_t = 1):
         super(AdaptiveAvgMaxPool2d, self).__init__()
         self.output_size = output_size
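To illustrate the format-aware reductions above (a sketch, not part of the commit): the same module reduces dims (2, 3) for NCHW and (1, 2) for NHWC, and the cat variant concatenates on the layout's channel dim when spatial dims are kept.

```python
import torch
from timm.layers.adaptive_avgmax_pool import FastAdaptiveAvgPool, FastAdaptiveCatAvgMaxPool

nchw = torch.randn(2, 64, 7, 7)
nhwc = nchw.permute(0, 2, 3, 1)

# same pooling, different reduction dims depending on the declared layout
print(FastAdaptiveAvgPool(flatten=True, input_fmt='NCHW')(nchw).shape)  # torch.Size([2, 64])
print(FastAdaptiveAvgPool(flatten=True, input_fmt='NHWC')(nhwc).shape)  # torch.Size([2, 64])

# without flatten, concat happens on the channel dim of the input format
cat = FastAdaptiveCatAvgMaxPool(flatten=False, input_fmt='NHWC')
print(cat(nhwc).shape)  # torch.Size([2, 1, 1, 128])
```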

@@ -68,7 +113,7 @@ def forward(self, x):
 
 
 class AdaptiveCatAvgMaxPool2d(nn.Module):
-    def __init__(self, output_size=1):
+    def __init__(self, output_size: _int_tuple_2_t = 1):
         super(AdaptiveCatAvgMaxPool2d, self).__init__()
         self.output_size = output_size

@@ -79,26 +124,41 @@ def forward(self, x):
 class SelectAdaptivePool2d(nn.Module):
     """Selectable global pooling layer with dynamic input kernel size
     """
-    def __init__(self, output_size=1, pool_type='fast', flatten=False):
+    def __init__(
+            self,
+            output_size: _int_tuple_2_t = 1,
+            pool_type: str = 'fast',
+            flatten: bool = False,
+            input_fmt: str = 'NCHW',
+    ):
         super(SelectAdaptivePool2d, self).__init__()
+        assert input_fmt in ('NCHW', 'NHWC')
         self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
-        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
-        if pool_type == '':
+        if not pool_type:
             self.pool = nn.Identity()  # pass through
-        elif pool_type == 'fast':
-            assert output_size == 1
-            self.pool = FastAdaptiveAvgPool2d(flatten)
+            self.flatten = nn.Flatten(1) if flatten else nn.Identity()
+        elif pool_type.startswith('fast') or input_fmt != 'NCHW':
+            assert output_size == 1, 'Fast pooling and non NCHW input formats require output_size == 1.'
+            if pool_type.endswith('catavgmax'):
+                self.pool = FastAdaptiveCatAvgMaxPool(flatten, input_fmt=input_fmt)
+            elif pool_type.endswith('avgmax'):
+                self.pool = FastAdaptiveAvgMaxPool(flatten, input_fmt=input_fmt)
+            elif pool_type.endswith('max'):
+                self.pool = FastAdaptiveMaxPool(flatten, input_fmt=input_fmt)
+            else:
+                self.pool = FastAdaptiveAvgPool(flatten, input_fmt=input_fmt)
             self.flatten = nn.Identity()
-        elif pool_type == 'avg':
-            self.pool = nn.AdaptiveAvgPool2d(output_size)
-        elif pool_type == 'avgmax':
-            self.pool = AdaptiveAvgMaxPool2d(output_size)
-        elif pool_type == 'catavgmax':
-            self.pool = AdaptiveCatAvgMaxPool2d(output_size)
-        elif pool_type == 'max':
-            self.pool = nn.AdaptiveMaxPool2d(output_size)
         else:
-            assert False, 'Invalid pool type: %s' % pool_type
+            assert input_fmt == 'NCHW'
+            if pool_type == 'avgmax':
+                self.pool = AdaptiveAvgMaxPool2d(output_size)
+            elif pool_type == 'catavgmax':
+                self.pool = AdaptiveCatAvgMaxPool2d(output_size)
+            elif pool_type == 'max':
+                self.pool = nn.AdaptiveMaxPool2d(output_size)
+            else:
+                self.pool = nn.AdaptiveAvgPool2d(output_size)
+            self.flatten = nn.Flatten(1) if flatten else nn.Identity()
 
     def is_identity(self):
         return not self.pool_type
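A sketch of the resulting dispatch in SelectAdaptivePool2d (usage assumes a timm build with these changes): 'fast*' pool types and any non-NCHW input format route to the tensor-reduction pools and require output_size == 1, while plain pool types on NCHW still use the nn.Adaptive*Pool2d modules.

```python
import torch
from timm.layers import SelectAdaptivePool2d

x = torch.randn(2, 64, 7, 7)

# classic NCHW path keeps spatial output_size support
pool = SelectAdaptivePool2d(output_size=2, pool_type='avg')
print(pool(x).shape)   # torch.Size([2, 64, 2, 2])

# 'fast' types reduce straight to a vector (output_size must stay 1)
fast = SelectAdaptivePool2d(pool_type='fastavgmax', flatten=True)
print(fast(x).shape)   # torch.Size([2, 64])

# non-NCHW layouts always take the fast path
nhwc_pool = SelectAdaptivePool2d(pool_type='max', flatten=True, input_fmt='NHWC')
print(nhwc_pool(x.permute(0, 2, 3, 1)).shape)   # torch.Size([2, 64])
```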

timm/layers/classifier.py

Lines changed: 60 additions & 18 deletions
@@ -15,13 +15,23 @@
 from .create_norm import get_norm_layer
 
 
-def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
+def _create_pool(
+        num_features: int,
+        num_classes: int,
+        pool_type: str = 'avg',
+        use_conv: bool = False,
+        input_fmt: Optional[str] = None,
+):
     flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
     if not pool_type:
         assert num_classes == 0 or use_conv,\
             'Pooling can only be disabled if classifier is also removed or conv classifier is used'
         flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
-    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
+    global_pool = SelectAdaptivePool2d(
+        pool_type=pool_type,
+        flatten=flatten_in_pool,
+        input_fmt=input_fmt,
+    )
     num_pooled_features = num_features * global_pool.feat_mult()
     return global_pool, num_pooled_features

@@ -36,9 +46,25 @@ def _create_fc(num_features, num_classes, use_conv=False):
     return fc
 
 
-def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
-    global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
-    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+def create_classifier(
+        num_features: int,
+        num_classes: int,
+        pool_type: str = 'avg',
+        use_conv: bool = False,
+        input_fmt: str = 'NCHW',
+):
+    global_pool, num_pooled_features = _create_pool(
+        num_features,
+        num_classes,
+        pool_type,
+        use_conv=use_conv,
+        input_fmt=input_fmt,
+    )
+    fc = _create_fc(
+        num_pooled_features,
+        num_classes,
+        use_conv=use_conv,
+    )
     return global_pool, fc
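For models that assemble their own heads, create_classifier now threads input_fmt through to the pooling layer. A minimal sketch (assuming the module path timm.layers.classifier shown above):

```python
import torch
from timm.layers.classifier import create_classifier

global_pool, fc = create_classifier(768, 1000, pool_type='avg', input_fmt='NHWC')

feats = torch.randn(2, 7, 7, 768)   # channels-last features
logits = fc(global_pool(feats))     # pool -> (2, 768), linear -> (2, 1000)
print(logits.shape)
```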

@@ -52,6 +78,7 @@ def __init__(
             pool_type: str = 'avg',
             drop_rate: float = 0.,
             use_conv: bool = False,
+            input_fmt: str = 'NCHW',
     ):
         """
         Args:
@@ -64,28 +91,43 @@ def __init__(
         self.drop_rate = drop_rate
         self.in_features = in_features
         self.use_conv = use_conv
-
-        self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
-        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+        self.input_fmt = input_fmt
+
+        self.global_pool, self.fc = create_classifier(
+            in_features,
+            num_classes,
+            pool_type,
+            use_conv=use_conv,
+            input_fmt=input_fmt,
+        )
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
 
-    def reset(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            if global_pool != self.global_pool.pool_type:
-                self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
-            self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
-        num_pooled_features = self.in_features * self.global_pool.feat_mult()
-        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
+    def reset(self, num_classes, pool_type=None):
+        if pool_type is not None and pool_type != self.global_pool.pool_type:
+            self.global_pool, self.fc = create_classifier(
+                self.in_features,
+                num_classes,
+                pool_type=pool_type,
+                use_conv=self.use_conv,
+                input_fmt=self.input_fmt,
+            )
+            self.flatten = nn.Flatten(1) if self.use_conv and pool_type else nn.Identity()
+        else:
+            num_pooled_features = self.in_features * self.global_pool.feat_mult()
+            self.fc = _create_fc(
+                num_pooled_features,
+                num_classes,
+                use_conv=self.use_conv,
+            )
 
     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
         if self.drop_rate:
             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
         if pre_logits:
-            return x.flatten(1)
-        else:
-            x = self.fc(x)
             return self.flatten(x)
+        x = self.fc(x)
+        return self.flatten(x)
 
 
 class NormMlpClassifierHead(nn.Module):
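The reset() rework keeps input_fmt sticky on the head: a pool-type change rebuilds both pool and fc with the stored format, while a classes-only reset just swaps the fc. A sketch of that behaviour (assuming a timm build with these changes):

```python
import torch
from timm.layers import ClassifierHead

head = ClassifierHead(in_features=512, num_classes=1000, pool_type='avg', input_fmt='NHWC')
x = torch.randn(2, 7, 7, 512)
print(head(x).shape)   # torch.Size([2, 1000])

head.reset(10)         # classifier only; pooling (and the NHWC handling) is kept
print(head(x).shape)   # torch.Size([2, 10])

head.reset(10, pool_type='avgmax')   # pool change rebuilds pool and fc, still NHWC-aware
print(head(x).shape)   # torch.Size([2, 10])
```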

timm/layers/format.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from enum import Enum
+from typing import Union
+
+import torch
+
+
+class Format(str, Enum):
+    NCHW = 'NCHW'
+    NHWC = 'NHWC'
+    NCL = 'NCL'
+    NLC = 'NLC'
+
+
+FormatT = Union[str, Format]
+
+
+def get_spatial_dim(fmt: FormatT):
+    fmt = Format(fmt)
+    if fmt is Format.NLC:
+        dim = (1,)
+    elif fmt is Format.NCL:
+        dim = (2,)
+    elif fmt is Format.NHWC:
+        dim = (1, 2)
+    else:
+        dim = (2, 3)
+    return dim
+
+
+def get_channel_dim(fmt: FormatT):
+    fmt = Format(fmt)
+    if fmt is Format.NHWC:
+        dim = 3
+    elif fmt is Format.NLC:
+        dim = 2
+    else:
+        dim = 1
+    return dim
+
+
+def nchw_to(x: torch.Tensor, fmt: Format):
+    if fmt == Format.NHWC:
+        x = x.permute(0, 2, 3, 1)
+    elif fmt == Format.NLC:
+        x = x.flatten(2).transpose(1, 2)
+    elif fmt == Format.NCL:
+        x = x.flatten(2)
+    return x
+
+
+def nhwc_to(x: torch.Tensor, fmt: Format):
+    if fmt == Format.NCHW:
+        x = x.permute(0, 3, 1, 2)
+    elif fmt == Format.NLC:
+        x = x.flatten(1, 2)
+    elif fmt == Format.NCL:
+        x = x.flatten(1, 2).transpose(1, 2)
+    return x
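A short sketch of the new helpers in use (names exactly as added above; since Format is a str Enum, either the enum member or its string value can be passed):

```python
import torch
from timm.layers import Format, get_spatial_dim, nchw_to, nhwc_to

x = torch.randn(2, 64, 7, 7)   # NCHW

print(nchw_to(x, Format.NHWC).shape)   # torch.Size([2, 7, 7, 64])
print(nchw_to(x, Format.NLC).shape)    # torch.Size([2, 49, 64])
print(nhwc_to(x.permute(0, 2, 3, 1), Format.NCL).shape)   # torch.Size([2, 64, 49])

# reduction dims used by the format-aware global pooling
print(get_spatial_dim('NHWC'))        # (1, 2)
print(get_spatial_dim(Format.NCHW))   # (2, 3)
```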

0 commit comments
