9
9
import keras .backend as K
10
10
11
11
12
- def conv_factory (x , nb_filter , dropout_rate = None , weight_decay = 1E-4 ):
12
+ def conv_factory (x , concat_axis , nb_filter ,
13
+ dropout_rate = None , weight_decay = 1E-4 ):
13
14
"""Apply BatchNorm, Relu 3x3Conv2D, optional dropout
14
15
15
16
:param x: Input keras network
17
+ :param concat_axis: int -- index of contatenate axis
16
18
:param nb_filter: int -- number of filters
17
19
:param dropout_rate: int -- dropout rate
18
20
:param weight_decay: int -- weight decay factor
@@ -21,7 +23,7 @@ def conv_factory(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
21
23
:rtype: keras network
22
24
"""
23
25
24
- x = BatchNormalization (axis = 1 ,
26
+ x = BatchNormalization (axis = concat_axis ,
25
27
gamma_regularizer = l2 (weight_decay ),
26
28
beta_regularizer = l2 (weight_decay ))(x )
27
29
x = Activation ('relu' )(x )
@@ -36,10 +38,12 @@ def conv_factory(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
36
38
return x
37
39
38
40
39
- def transition (x , nb_filter , dropout_rate = None , weight_decay = 1E-4 ):
41
+ def transition (x , concat_axis , nb_filter ,
42
+ dropout_rate = None , weight_decay = 1E-4 ):
40
43
"""Apply BatchNorm, Relu 1x1Conv2D, optional dropout and Maxpooling2D
41
44
42
45
:param x: keras model
46
+ :param concat_axis: int -- index of contatenate axis
43
47
:param nb_filter: int -- number of filters
44
48
:param dropout_rate: int -- dropout rate
45
49
:param weight_decay: int -- weight decay factor
@@ -49,7 +53,7 @@ def transition(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
49
53
50
54
"""
51
55
52
- x = BatchNormalization (axis = 1 ,
56
+ x = BatchNormalization (axis = concat_axis ,
53
57
gamma_regularizer = l2 (weight_decay ),
54
58
beta_regularizer = l2 (weight_decay ))(x )
55
59
x = Activation ('relu' )(x )
@@ -65,12 +69,13 @@ def transition(x, nb_filter, dropout_rate=None, weight_decay=1E-4):
65
69
return x
66
70
67
71
68
- def denseblock (x , nb_layers , nb_filter , growth_rate ,
72
+ def denseblock (x , concat_axis , nb_layers , nb_filter , growth_rate ,
69
73
dropout_rate = None , weight_decay = 1E-4 ):
70
74
"""Build a denseblock where the output of each
71
75
conv_factory is fed to subsequent ones
72
76
73
77
:param x: keras model
78
+ :param concat_axis: int -- index of contatenate axis
74
79
:param nb_layers: int -- the number of layers of conv_
75
80
factory to append to the model.
76
81
:param nb_filter: int -- number of filters
@@ -84,26 +89,23 @@ def denseblock(x, nb_layers, nb_filter, growth_rate,
84
89
85
90
list_feat = [x ]
86
91
87
- if K .image_dim_ordering () == "th" :
88
- concat_axis = 1
89
- elif K .image_dim_ordering () == "tf" :
90
- concat_axis = - 1
91
-
92
92
for i in range (nb_layers ):
93
- x = conv_factory (x , growth_rate , dropout_rate , weight_decay )
93
+ x = conv_factory (x , concat_axis , growth_rate ,
94
+ dropout_rate , weight_decay )
94
95
list_feat .append (x )
95
96
x = Concatenate (axis = concat_axis )(list_feat )
96
97
nb_filter += growth_rate
97
98
98
99
return x , nb_filter
99
100
100
101
101
- def denseblock_altern (x , nb_layers , nb_filter , growth_rate ,
102
+ def denseblock_altern (x , concat_axis , nb_layers , nb_filter , growth_rate ,
102
103
dropout_rate = None , weight_decay = 1E-4 ):
103
104
"""Build a denseblock where the output of each conv_factory
104
105
is fed to subsequent ones. (Alternative of a above)
105
106
106
107
:param x: keras model
108
+ :param concat_axis: int -- index of contatenate axis
107
109
:param nb_layers: int -- the number of layers of conv_
108
110
factory to append to the model.
109
111
:param nb_filter: int -- number of filters
@@ -117,13 +119,9 @@ def denseblock_altern(x, nb_layers, nb_filter, growth_rate,
117
119
above is that the one above
118
120
"""
119
121
120
- if K .image_dim_ordering () == "th" :
121
- concat_axis = 1
122
- elif K .image_dim_ordering () == "tf" :
123
- concat_axis = - 1
124
-
125
122
for i in range (nb_layers ):
126
- merge_tensor = conv_factory (x , growth_rate , dropout_rate , weight_decay )
123
+ merge_tensor = conv_factory (x , concat_axis , growth_rate ,
124
+ dropout_rate , weight_decay )
127
125
x = Concatenate (axis = concat_axis )([merge_tensor , x ])
128
126
nb_filter += growth_rate
129
127
@@ -147,6 +145,11 @@ def DenseNet(nb_classes, img_dim, depth, nb_dense_block, growth_rate,
147
145
:rtype: keras model
148
146
149
147
"""
148
+
149
+ if K .image_dim_ordering () == "th" :
150
+ concat_axis = 1
151
+ elif K .image_dim_ordering () == "tf" :
152
+ concat_axis = - 1
150
153
151
154
model_input = Input (shape = img_dim )
152
155
@@ -165,19 +168,21 @@ def DenseNet(nb_classes, img_dim, depth, nb_dense_block, growth_rate,
165
168
166
169
# Add dense blocks
167
170
for block_idx in range (nb_dense_block - 1 ):
168
- x , nb_filter = denseblock (x , nb_layers , nb_filter , growth_rate ,
171
+ x , nb_filter = denseblock (x , concat_axis , nb_layers ,
172
+ nb_filter , growth_rate ,
169
173
dropout_rate = dropout_rate ,
170
174
weight_decay = weight_decay )
171
175
# add transition
172
176
x = transition (x , nb_filter , dropout_rate = dropout_rate ,
173
177
weight_decay = weight_decay )
174
178
175
179
# The last denseblock does not have a transition
176
- x , nb_filter = denseblock (x , nb_layers , nb_filter , growth_rate ,
180
+ x , nb_filter = denseblock (x , concat_axis , nb_layers ,
181
+ nb_filter , growth_rate ,
177
182
dropout_rate = dropout_rate ,
178
183
weight_decay = weight_decay )
179
184
180
- x = BatchNormalization (axis = 1 ,
185
+ x = BatchNormalization (axis = concat_axis ,
181
186
gamma_regularizer = l2 (weight_decay ),
182
187
beta_regularizer = l2 (weight_decay ))(x )
183
188
x = Activation ('relu' )(x )
0 commit comments