Skip to content

Commit 6fe1531

Browse files
committed
fix BN axis error
1 parent 1a3257a commit 6fe1531

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

DenseNet/densenet.py

+26-21
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
import keras.backend as K
1010

1111

12-
def conv_factory(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
12+
def conv_factory(x, concat_axis, nb_filter,
13+
dropout_rate=None, weight_decay=1E-4):
1314
"""Apply BatchNorm, Relu 3x3Conv2D, optional dropout
1415
1516
:param x: Input keras network
17+
:param concat_axis: int -- index of concatenate axis
1618
:param nb_filter: int -- number of filters
1719
:param dropout_rate: int -- dropout rate
1820
:param weight_decay: int -- weight decay factor
@@ -21,7 +23,7 @@ def conv_factory(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
2123
:rtype: keras network
2224
"""
2325

24-
x = BatchNormalization(axis=1,
26+
x = BatchNormalization(axis=concat_axis,
2527
gamma_regularizer=l2(weight_decay),
2628
beta_regularizer=l2(weight_decay))(x)
2729
x = Activation('relu')(x)
@@ -36,10 +38,12 @@ def conv_factory(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
3638
return x
3739

3840

39-
def transition(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
41+
def transition(x, concat_axis, nb_filter,
42+
dropout_rate=None, weight_decay=1E-4):
4043
"""Apply BatchNorm, Relu 1x1Conv2D, optional dropout and Maxpooling2D
4144
4245
:param x: keras model
46+
:param concat_axis: int -- index of concatenate axis
4347
:param nb_filter: int -- number of filters
4448
:param dropout_rate: int -- dropout rate
4549
:param weight_decay: int -- weight decay factor
@@ -49,7 +53,7 @@ def transition(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
4953
5054
"""
5155

52-
x = BatchNormalization(axis=1,
56+
x = BatchNormalization(axis=concat_axis,
5357
gamma_regularizer=l2(weight_decay),
5458
beta_regularizer=l2(weight_decay))(x)
5559
x = Activation('relu')(x)
@@ -65,12 +69,13 @@ def transition(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
6569
return x
6670

6771

68-
def denseblock(x, nb_layers, nb_filter, growth_rate,
72+
def denseblock(x, concat_axis, nb_layers, nb_filter, growth_rate,
6973
dropout_rate=None, weight_decay=1E-4):
7074
"""Build a denseblock where the output of each
7175
conv_factory is fed to subsequent ones
7276
7377
:param x: keras model
78+
:param concat_axis: int -- index of concatenate axis
7479
:param nb_layers: int -- the number of layers of conv_
7580
factory to append to the model.
7681
:param nb_filter: int -- number of filters
@@ -84,26 +89,23 @@ def denseblock(x, nb_layers, nb_filter, growth_rate,
8489

8590
list_feat = [x]
8691

87-
if K.image_dim_ordering() == "th":
88-
concat_axis = 1
89-
elif K.image_dim_ordering() == "tf":
90-
concat_axis = -1
91-
9292
for i in range(nb_layers):
93-
x = conv_factory(x, growth_rate, dropout_rate, weight_decay)
93+
x = conv_factory(x, concat_axis, growth_rate,
94+
dropout_rate, weight_decay)
9495
list_feat.append(x)
9596
x = Concatenate(axis=concat_axis)(list_feat)
9697
nb_filter += growth_rate
9798

9899
return x, nb_filter
99100

100101

101-
def denseblock_altern(x, nb_layers, nb_filter, growth_rate,
102+
def denseblock_altern(x, concat_axis, nb_layers, nb_filter, growth_rate,
102103
dropout_rate=None, weight_decay=1E-4):
103104
"""Build a denseblock where the output of each conv_factory
104105
is fed to subsequent ones. (Alternative of a above)
105106
106107
:param x: keras model
108+
:param concat_axis: int -- index of concatenate axis
107109
:param nb_layers: int -- the number of layers of conv_
108110
factory to append to the model.
109111
:param nb_filter: int -- number of filters
@@ -117,13 +119,9 @@ def denseblock_altern(x, nb_layers, nb_filter, growth_rate,
117119
above is that the one above
118120
"""
119121

120-
if K.image_dim_ordering() == "th":
121-
concat_axis = 1
122-
elif K.image_dim_ordering() == "tf":
123-
concat_axis = -1
124-
125122
for i in range(nb_layers):
126-
merge_tensor = conv_factory(x, growth_rate, dropout_rate, weight_decay)
123+
merge_tensor = conv_factory(x, concat_axis, growth_rate,
124+
dropout_rate, weight_decay)
127125
x = Concatenate(axis=concat_axis)([merge_tensor, x])
128126
nb_filter += growth_rate
129127

@@ -147,6 +145,11 @@ def DenseNet(nb_classes, img_dim, depth, nb_dense_block, growth_rate,
147145
:rtype: keras model
148146
149147
"""
148+
149+
if K.image_dim_ordering() == "th":
150+
concat_axis = 1
151+
elif K.image_dim_ordering() == "tf":
152+
concat_axis = -1
150153

151154
model_input = Input(shape=img_dim)
152155

@@ -165,19 +168,21 @@ def DenseNet(nb_classes, img_dim, depth, nb_dense_block, growth_rate,
165168

166169
# Add dense blocks
167170
for block_idx in range(nb_dense_block - 1):
168-
x, nb_filter = denseblock(x, nb_layers, nb_filter, growth_rate,
171+
x, nb_filter = denseblock(x, concat_axis, nb_layers,
172+
nb_filter, growth_rate,
169173
dropout_rate=dropout_rate,
170174
weight_decay=weight_decay)
171175
# add transition
172176
x = transition(x, nb_filter, dropout_rate=dropout_rate,
173177
weight_decay=weight_decay)
174178

175179
# The last denseblock does not have a transition
176-
x, nb_filter = denseblock(x, nb_layers, nb_filter, growth_rate,
180+
x, nb_filter = denseblock(x, concat_axis, nb_layers,
181+
nb_filter, growth_rate,
177182
dropout_rate=dropout_rate,
178183
weight_decay=weight_decay)
179184

180-
x = BatchNormalization(axis=1,
185+
x = BatchNormalization(axis=concat_axis,
181186
gamma_regularizer=l2(weight_decay),
182187
beta_regularizer=l2(weight_decay))(x)
183188
x = Activation('relu')(x)

0 commit comments

Comments
 (0)