This repository was archived by the owner on Jan 13, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 282
/
Copy pathresnext.py
183 lines (154 loc) · 7.07 KB
/
resnext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ResNeXt (50, 101, 152)
# Paper: https://arxiv.org/pdf/1611.05431.pdf
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, ReLU, BatchNormalization, Add
from tensorflow.keras.layers import Concatenate, Dense, GlobalAveragePooling2D, Lambda
def stem(inputs):
    """ Construct the Stem Convolution Group
    inputs : input tensor of the model

    Applies a strided 7x7 convolution, batch normalization and ReLU,
    followed by a strided 3x3 max pooling, and returns the result.
    """
    # The stem is a plain pipeline of layers; build them in order, then
    # thread the tensor through each one.
    pipeline = [
        # 7x7 strided convolution; no bias since BatchNormalization follows
        Conv2D(64, (7, 7), strides=(2, 2), padding='same',
               kernel_initializer='he_normal', use_bias=False),
        BatchNormalization(),
        ReLU(),
        # Second downsampling step
        MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'),
    ]
    out = inputs
    for layer in pipeline:
        out = layer(out)
    return out
def learner(x, groups, cardinality=32):
    """ Construct the Learner
    x          : input to the learner
    groups     : sequence of groups, each (filters in, filters out, number of blocks).
                 It is read but never modified.
    cardinality: width of group convolution
    """
    # Read the groups by index instead of groups.pop(0): popping mutated the
    # caller's list (e.g. the shared meta-parameter table), so a second call
    # with the same list would silently build a different network.
    # An empty sequence still raises IndexError, as before.
    filters_in, filters_out, n_blocks = groups[0]
    # First ResNeXt Group (not-strided)
    x = group(x, filters_in, filters_out, n_blocks, strides=(1, 1), cardinality=cardinality)
    # Remaining ResNeXt groups (use group()'s default strides of (2, 2))
    for filters_in, filters_out, n_blocks in groups[1:]:
        x = group(x, filters_in, filters_out, n_blocks, cardinality=cardinality)
    return x
def group(x, filters_in, filters_out, n_blocks, cardinality=32, strides=(2, 2)):
    """ Construct a Residual group
    x          : input to the group
    filters_in : number of filters (channels) at the input convolution
    filters_out: number of filters (channels) at the output convolution
    n_blocks   : total number of residual blocks in the group, including the
                 leading projection block
    cardinality: width of group convolution
    strides    : whether it's a strided convolution, i.e. (2, 2) vs (1, 1)
    """
    # The first block uses a projection shortcut, because the number of channels
    # changes on entry to the group (and, when strided, the feature maps are
    # downsampled: strides=(2, 2) reduces their area by 75%).
    x = projection_block(x, filters_in, filters_out, strides=strides, cardinality=cardinality)
    # Remaining blocks: the projection block already counts as one of the
    # n_blocks, so only n_blocks - 1 identity blocks follow. Previously
    # n_blocks identity blocks were added, giving one extra block per group
    # (ResNeXt-50 stages should have 3, 4, 6, 3 blocks).
    for _ in range(n_blocks - 1):
        x = identity_block(x, filters_in, filters_out, cardinality=cardinality)
    return x
def identity_block(x, filters_in, filters_out, cardinality=32):
    """ Construct a ResNeXt block with identity link
    x          : input to block
    filters_in : number of filters (channels) at the input convolution;
                 must be divisible by cardinality
    filters_out: number of filters (channels) at the output convolution; it is
                 added element-wise to x, so presumably must match x's channel
                 count (TODO confirm at call sites)
    cardinality: width of group convolution
    Raises ValueError if filters_in is not divisible by cardinality (otherwise
    the trailing channels would be silently dropped by the split).
    """
    if filters_in % cardinality != 0:
        raise ValueError('filters_in (%d) must be divisible by cardinality (%d)'
                         % (filters_in, cardinality))
    # Remember the input
    shortcut = x
    # Dimensionality Reduction
    x = Conv2D(filters_in, (1, 1), strides=(1, 1),
               padding='same', kernel_initializer='he_normal', use_bias=False)(shortcut)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # Cardinality (Wide) Layer (split-transform)
    filters_card = filters_in // cardinality
    branches = []
    for i in range(cardinality):
        start = i * filters_card
        stop = start + filters_card
        # Bind start/stop as default arguments. A plain closure over the loop
        # variable is late-bound, and Keras may re-invoke the Lambda's function
        # after construction (e.g. when the built Model is called on data).
        # Every branch would then slice the last channel group.
        split = Lambda(lambda z, start=start, stop=stop: z[:, :, :, start:stop])(x)
        branches.append(Conv2D(filters_card, (3, 3), strides=(1, 1),
                               padding='same', kernel_initializer='he_normal',
                               use_bias=False)(split))
    # Concatenate the outputs of the cardinality layer together (merge)
    x = Concatenate()(branches)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # Dimensionality restoration
    x = Conv2D(filters_out, (1, 1), strides=(1, 1),
               padding='same', kernel_initializer='he_normal', use_bias=False)(x)
    x = BatchNormalization()(x)
    # Identity Link: Add the shortcut (input) to the output of the block
    x = Add()([shortcut, x])
    x = ReLU()(x)
    return x
def projection_block(x, filters_in, filters_out, cardinality=32, strides=(2, 2)):
    """ Construct a ResNeXt block with projection shortcut
    x          : input to the block
    filters_in : number of filters (channels) at the input convolution;
                 must be divisible by cardinality
    filters_out: number of filters (channels) at the output convolution
    cardinality: width of group convolution
    strides    : whether entry convolution is strided (i.e., (2, 2) vs (1, 1))
    Raises ValueError if filters_in is not divisible by cardinality (otherwise
    the trailing channels would be silently dropped by the split).
    """
    if filters_in % cardinality != 0:
        raise ValueError('filters_in (%d) must be divisible by cardinality (%d)'
                         % (filters_in, cardinality))
    # Construct the projection shortcut. A 1x1 convolution, using the same
    # strides as the block's 3x3 convolutions, maps the input to filters_out
    # channels and to the output's spatial size so it can be added to the
    # block output. No bias: BatchNormalization follows, as with every other
    # convolution here.
    shortcut = Conv2D(filters_out, (1, 1), strides=strides,
                      padding='same', kernel_initializer='he_normal', use_bias=False)(x)
    shortcut = BatchNormalization()(shortcut)
    # Dimensionality Reduction
    x = Conv2D(filters_in, (1, 1), strides=(1, 1),
               padding='same', kernel_initializer='he_normal', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # Cardinality (Wide) Layer (split-transform)
    filters_card = filters_in // cardinality
    branches = []
    for i in range(cardinality):
        start = i * filters_card
        stop = start + filters_card
        # Bind start/stop as default arguments. A plain closure over the loop
        # variable is late-bound, and Keras may re-invoke the Lambda's function
        # after construction (e.g. when the built Model is called on data).
        # Every branch would then slice the last channel group.
        split = Lambda(lambda z, start=start, stop=stop: z[:, :, :, start:stop])(x)
        # Each branch's 3x3 convolution carries the block's stride
        branches.append(Conv2D(filters_card, (3, 3), strides=strides,
                               padding='same', kernel_initializer='he_normal',
                               use_bias=False)(split))
    # Concatenate the outputs of the cardinality layer together (merge)
    x = Concatenate()(branches)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    # Dimensionality restoration
    x = Conv2D(filters_out, (1, 1), strides=(1, 1),
               padding='same', kernel_initializer='he_normal', use_bias=False)(x)
    x = BatchNormalization()(x)
    # Projection Link: Add the projected shortcut to the output of the block
    x = Add()([shortcut, x])
    x = ReLU()(x)
    return x
def classifier(x, n_classes):
    """ Construct the Classifier
    x         : input to the classifier (the learner's output feature maps)
    n_classes : number of output classes

    Returns the softmax output over n_classes.
    """
    # Collapse every feature map to its spatial average
    pooled = GlobalAveragePooling2D()(x)
    # Fully-connected output layer producing the class probabilities
    head = Dense(n_classes, activation='softmax', kernel_initializer='he_normal')
    return head(pooled)
# Meta-parameter: per group, number of filters in, number of filters out and
# number of blocks
groups = { 50 : [ (128, 256, 3), (256, 512, 4), (512, 1024, 6), (1024, 2048, 3)], # ResNeXt 50
           101: [ (128, 256, 3), (256, 512, 4), (512, 1024, 23), (1024, 2048, 3)], # ResNeXt 101
           152: [ (128, 256, 3), (256, 512, 8), (512, 1024, 36), (1024, 2048, 3)] # ResNeXt 152
         }
# Meta-parameter: width of group convolution
cardinality = 32
# The input tensor
inputs = Input(shape=(224, 224, 3))
# The Stem Group
x = stem(inputs)
# The Learner. Pass a copy so the shared meta-parameter table is not modified
# if learner() consumes the list it is given.
x = learner(x, list(groups[50]), cardinality)
# The Classifier for 1000 classes
outputs = classifier(x, 1000)
# Instantiate the Model
model = Model(inputs, outputs)