9
9
10
10
Hacked together by / Copyright 2020 Ross Wightman
11
11
"""
12
+ from typing import Optional , Tuple , Union
13
+
12
14
import torch
13
15
import torch .nn as nn
14
16
import torch .nn .functional as F
15
17
18
+ from .format import get_spatial_dim , get_channel_dim
19
+
20
+ _int_tuple_2_t = Union [int , Tuple [int , int ]]
21
+
16
22
17
23
def adaptive_pool_feat_mult (pool_type = 'avg' ):
18
- if pool_type == 'catavgmax' :
24
+ if pool_type . endswith ( 'catavgmax' ) :
19
25
return 2
20
26
else :
21
27
return 1
22
28
23
29
24
- def adaptive_avgmax_pool2d (x , output_size = 1 ):
30
+ def adaptive_avgmax_pool2d (x , output_size : _int_tuple_2_t = 1 ):
25
31
x_avg = F .adaptive_avg_pool2d (x , output_size )
26
32
x_max = F .adaptive_max_pool2d (x , output_size )
27
33
return 0.5 * (x_avg + x_max )
28
34
29
35
30
- def adaptive_catavgmax_pool2d (x , output_size = 1 ):
36
+ def adaptive_catavgmax_pool2d (x , output_size : _int_tuple_2_t = 1 ):
31
37
x_avg = F .adaptive_avg_pool2d (x , output_size )
32
38
x_max = F .adaptive_max_pool2d (x , output_size )
33
39
return torch .cat ((x_avg , x_max ), 1 )
34
40
35
41
36
- def select_adaptive_pool2d (x , pool_type = 'avg' , output_size = 1 ):
42
+ def select_adaptive_pool2d (x , pool_type = 'avg' , output_size : _int_tuple_2_t = 1 ):
37
43
"""Selectable global pooling function with dynamic input kernel size
38
44
"""
39
45
if pool_type == 'avg' :
@@ -49,17 +55,56 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
49
55
return x
50
56
51
57
52
- class FastAdaptiveAvgPool2d (nn .Module ):
53
- def __init__ (self , flatten = False ):
54
- super (FastAdaptiveAvgPool2d , self ).__init__ ()
58
+ class FastAdaptiveAvgPool (nn .Module ):
59
+ def __init__ (self , flatten : bool = False , input_fmt : F = 'NCHW' ):
60
+ super (FastAdaptiveAvgPool , self ).__init__ ()
61
+ self .flatten = flatten
62
+ self .dim = get_spatial_dim (input_fmt )
63
+
64
+ def forward (self , x ):
65
+ return x .mean (self .dim , keepdim = not self .flatten )
66
+
67
+
68
+ class FastAdaptiveMaxPool (nn .Module ):
69
+ def __init__ (self , flatten : bool = False , input_fmt : str = 'NCHW' ):
70
+ super (FastAdaptiveMaxPool , self ).__init__ ()
55
71
self .flatten = flatten
72
+ self .dim = get_spatial_dim (input_fmt )
73
+
74
+ def forward (self , x ):
75
+ return x .amax (self .dim , keepdim = not self .flatten )
76
+
77
+
78
+ class FastAdaptiveAvgMaxPool (nn .Module ):
79
+ def __init__ (self , flatten : bool = False , input_fmt : str = 'NCHW' ):
80
+ super (FastAdaptiveAvgMaxPool , self ).__init__ ()
81
+ self .flatten = flatten
82
+ self .dim = get_spatial_dim (input_fmt )
83
+
84
+ def forward (self , x ):
85
+ x_avg = x .mean (self .dim , keepdim = not self .flatten )
86
+ x_max = x .amax (self .dim , keepdim = not self .flatten )
87
+ return 0.5 * x_avg + 0.5 * x_max
88
+
89
+
90
+ class FastAdaptiveCatAvgMaxPool (nn .Module ):
91
+ def __init__ (self , flatten : bool = False , input_fmt : str = 'NCHW' ):
92
+ super (FastAdaptiveCatAvgMaxPool , self ).__init__ ()
93
+ self .flatten = flatten
94
+ self .dim_reduce = get_spatial_dim (input_fmt )
95
+ if flatten :
96
+ self .dim_cat = 1
97
+ else :
98
+ self .dim_cat = get_channel_dim (input_fmt )
56
99
57
100
def forward (self , x ):
58
- return x .mean ((2 , 3 ), keepdim = not self .flatten )
101
+ x_avg = x .mean (self .dim_reduce , keepdim = not self .flatten )
102
+ x_max = x .amax (self .dim_reduce , keepdim = not self .flatten )
103
+ return torch .cat ((x_avg , x_max ), self .dim_cat )
59
104
60
105
61
106
class AdaptiveAvgMaxPool2d (nn .Module ):
62
- def __init__ (self , output_size = 1 ):
107
+ def __init__ (self , output_size : _int_tuple_2_t = 1 ):
63
108
super (AdaptiveAvgMaxPool2d , self ).__init__ ()
64
109
self .output_size = output_size
65
110
@@ -68,7 +113,7 @@ def forward(self, x):
68
113
69
114
70
115
class AdaptiveCatAvgMaxPool2d (nn .Module ):
71
- def __init__ (self , output_size = 1 ):
116
+ def __init__ (self , output_size : _int_tuple_2_t = 1 ):
72
117
super (AdaptiveCatAvgMaxPool2d , self ).__init__ ()
73
118
self .output_size = output_size
74
119
@@ -79,26 +124,41 @@ def forward(self, x):
79
124
class SelectAdaptivePool2d (nn .Module ):
80
125
"""Selectable global pooling layer with dynamic input kernel size
81
126
"""
82
- def __init__ (self , output_size = 1 , pool_type = 'fast' , flatten = False ):
127
+ def __init__ (
128
+ self ,
129
+ output_size : _int_tuple_2_t = 1 ,
130
+ pool_type : str = 'fast' ,
131
+ flatten : bool = False ,
132
+ input_fmt : str = 'NCHW' ,
133
+ ):
83
134
super (SelectAdaptivePool2d , self ).__init__ ()
135
+ assert input_fmt in ('NCHW' , 'NHWC' )
84
136
self .pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
85
- self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
86
- if pool_type == '' :
137
+ if not pool_type :
87
138
self .pool = nn .Identity () # pass through
88
- elif pool_type == 'fast' :
89
- assert output_size == 1
90
- self .pool = FastAdaptiveAvgPool2d (flatten )
139
+ self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
140
+ elif pool_type .startswith ('fast' ) or input_fmt != 'NCHW' :
141
+ assert output_size == 1 , 'Fast pooling and non NCHW input formats require output_size == 1.'
142
+ if pool_type .endswith ('avgmax' ):
143
+ self .pool = FastAdaptiveAvgMaxPool (flatten , input_fmt = input_fmt )
144
+ elif pool_type .endswith ('catavgmax' ):
145
+ self .pool = FastAdaptiveCatAvgMaxPool (flatten , input_fmt = input_fmt )
146
+ elif pool_type .endswith ('max' ):
147
+ self .pool = FastAdaptiveMaxPool (flatten , input_fmt = input_fmt )
148
+ else :
149
+ self .pool = FastAdaptiveAvgPool (flatten , input_fmt = input_fmt )
91
150
self .flatten = nn .Identity ()
92
- elif pool_type == 'avg' :
93
- self .pool = nn .AdaptiveAvgPool2d (output_size )
94
- elif pool_type == 'avgmax' :
95
- self .pool = AdaptiveAvgMaxPool2d (output_size )
96
- elif pool_type == 'catavgmax' :
97
- self .pool = AdaptiveCatAvgMaxPool2d (output_size )
98
- elif pool_type == 'max' :
99
- self .pool = nn .AdaptiveMaxPool2d (output_size )
100
151
else :
101
- assert False , 'Invalid pool type: %s' % pool_type
152
+ assert input_fmt == 'NCHW'
153
+ if pool_type == 'avgmax' :
154
+ self .pool = AdaptiveAvgMaxPool2d (output_size )
155
+ elif pool_type == 'catavgmax' :
156
+ self .pool = AdaptiveCatAvgMaxPool2d (output_size )
157
+ elif pool_type == 'max' :
158
+ self .pool = nn .AdaptiveMaxPool2d (output_size )
159
+ else :
160
+ self .pool = nn .AdaptiveAvgPool2d (output_size )
161
+ self .flatten = nn .Flatten (1 ) if flatten else nn .Identity ()
102
162
103
163
def is_identity (self ):
104
164
return not self .pool_type
0 commit comments