1
+ """
2
+ ResNet50
3
+ 2017/12/06
4
+ """
5
+
6
+ import tensorflow as tf
7
+ from tensorflow .python .training import moving_averages
8
+
9
# Weight initializers (Xavier/Glorot): one for dense layers, one for conv kernels.
fc_initializer = tf.contrib.layers.xavier_initializer
conv2d_initializer = tf.contrib.layers.xavier_initializer_conv2d
11
+
12
# create weight variable
def create_var(name, shape, initializer, trainable=True):
    """Create (or fetch, under scope reuse) a float32 variable."""
    return tf.get_variable(
        name,
        shape=shape,
        dtype=tf.float32,
        initializer=initializer,
        trainable=trainable,
    )
16
+
17
# conv2d layer
def conv2d(x, num_outputs, kernel_size, stride=1, scope="conv2d"):
    """Square-kernel SAME-padded convolution over NHWC input `x` (no bias term)."""
    in_channels = x.get_shape()[-1]
    with tf.variable_scope(scope):
        kernel_shape = [kernel_size, kernel_size, in_channels, num_outputs]
        kernel = create_var("kernel", kernel_shape, conv2d_initializer())
        strides = [1, stride, stride, 1]
        return tf.nn.conv2d(x, kernel, strides=strides, padding="SAME")
26
+
27
# fully connected layer
def fc(x, num_outputs, scope="fc"):
    """Affine layer x @ W + b with Xavier weights and zero-initialized bias."""
    in_dim = x.get_shape()[-1]
    with tf.variable_scope(scope):
        w = create_var("weight", [in_dim, num_outputs], fc_initializer())
        b = create_var("bias", [num_outputs], tf.zeros_initializer())
        return tf.nn.xw_plus_b(x, w, b)
36
+
37
+
38
# batch norm layer
def batch_norm(x, decay=0.999, epsilon=1e-03, is_training=True,
               scope="scope"):
    """Batch normalization over all axes except the last (channels).

    During training, normalizes with the current mini-batch moments and
    registers moving-average update ops; during inference, normalizes with
    the stored moving statistics.

    NOTE(review): the update ops are only added to tf.GraphKeys.UPDATE_OPS —
    the training loop must explicitly run that collection (e.g. as a control
    dependency of the train op) or the moving statistics never change.
    Confirm the caller does this.
    """
    x_shape = x.get_shape()
    num_inputs = x_shape[-1]
    # normalize over batch + spatial dims, keeping the channel dim
    reduce_dims = list(range(len(x_shape) - 1))
    with tf.variable_scope(scope):
        # learned shift/scale applied after normalization
        beta = create_var("beta", [num_inputs,],
                          initializer=tf.zeros_initializer())
        gamma = create_var("gamma", [num_inputs,],
                           initializer=tf.ones_initializer())
        # for inference
        moving_mean = create_var("moving_mean", [num_inputs,],
                                 initializer=tf.zeros_initializer(),
                                 trainable=False)
        moving_variance = create_var("moving_variance", [num_inputs],
                                     initializer=tf.ones_initializer(),
                                     trainable=False)
        if is_training:
            # batch statistics for the forward pass
            mean, variance = tf.nn.moments(x, axes=reduce_dims)
            # EMA updates of the stored statistics (run via UPDATE_OPS)
            update_move_mean = moving_averages.assign_moving_average(moving_mean,
                                                                     mean, decay=decay)
            update_move_variance = moving_averages.assign_moving_average(moving_variance,
                                                                         variance, decay=decay)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_move_mean)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_move_variance)
        else:
            mean, variance = moving_mean, moving_variance
        return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
67
+
68
+
69
# avg pool layer
def avg_pool(x, pool_size, scope):
    """Non-overlapping average pooling (stride == window), VALID padding."""
    window = [1, pool_size, pool_size, 1]
    with tf.variable_scope(scope):
        return tf.nn.avg_pool(x, window, strides=window, padding="VALID")
74
+
75
# max pool layer
def max_pool(x, pool_size, stride, scope):
    """Max pooling with independent window size and stride, SAME padding."""
    ksize = [1, pool_size, pool_size, 1]
    strides = [1, stride, stride, 1]
    with tf.variable_scope(scope):
        return tf.nn.max_pool(x, ksize, strides, padding="SAME")
80
+
81
class ResNet50(object):
    """ResNet-50 graph builder (TF1 graph mode, NHWC inputs).

    The full network is constructed at __init__ time; results are exposed
    as `self.logits` and `self.predictions` (softmax over the logits).
    The shape comments assume 224x224 input — the 7x7 avg pool at the head
    hard-codes that spatial size.
    """

    def __init__(self, inputs, num_classes=1000, is_training=True,
                 scope="resnet50"):
        self.inputs = inputs
        self.is_training = is_training
        self.num_classes = num_classes

        with tf.variable_scope(scope):
            # stem: 7x7/2 conv, BN+ReLU, 3x3/2 max pool
            net = conv2d(inputs, 64, 7, 2, scope="conv1")    # [batch, 112, 112, 64]
            net = tf.nn.relu(batch_norm(net, is_training=self.is_training,
                                        scope="bn1"))
            net = max_pool(net, 3, 2, scope="maxpool1")      # [batch, 56, 56, 64]
            # four stages of bottleneck units: (output width, repeat count)
            net = self._block(net, 256, 3, init_stride=1,
                              is_training=self.is_training,
                              scope="block2")                # [batch, 56, 56, 256]
            net = self._block(net, 512, 4, is_training=self.is_training,
                              scope="block3")                # [batch, 28, 28, 512]
            net = self._block(net, 1024, 6, is_training=self.is_training,
                              scope="block4")                # [batch, 14, 14, 1024]
            net = self._block(net, 2048, 3, is_training=self.is_training,
                              scope="block5")                # [batch, 7, 7, 2048]
            # head: spatial average pool, squeeze, linear classifier
            net = avg_pool(net, 7, scope="avgpool5")         # [batch, 1, 1, 2048]
            net = tf.squeeze(net, [1, 2], name="SpatialSqueeze")  # [batch, 2048]
            self.logits = fc(net, self.num_classes, "fc6")   # [batch, num_classes]
            self.predictions = tf.nn.softmax(self.logits)

    def _block(self, x, n_out, n, init_stride=2, is_training=True,
               scope="block"):
        """Stack `n` bottleneck units; only the first may downsample."""
        with tf.variable_scope(scope):
            h_out = n_out // 4  # bottleneck width is a quarter of the output
            # NOTE(review): "bottlencek" is a typo kept verbatim — these scope
            # strings become checkpoint variable names; renaming would break
            # loading of existing checkpoints.
            out = self._bottleneck(x, h_out, n_out, stride=init_stride,
                                   is_training=is_training,
                                   scope="bottlencek1")
            for idx in range(1, n):
                out = self._bottleneck(out, h_out, n_out,
                                       is_training=is_training,
                                       scope=("bottlencek%s" % (idx + 1)))
            return out

    def _bottleneck(self, x, h_out, n_out, stride=None, is_training=True,
                    scope="bottleneck"):
        """A residual bottleneck unit: 1x1 -> 3x3 -> 1x1, plus shortcut."""
        n_in = x.get_shape()[-1]
        if stride is None:
            # downsample exactly when the unit changes the channel count
            stride = 1 if n_in == n_out else 2

        with tf.variable_scope(scope):
            h = conv2d(x, h_out, 1, stride=stride, scope="conv_1")
            h = tf.nn.relu(batch_norm(h, is_training=is_training, scope="bn_1"))
            h = conv2d(h, h_out, 3, stride=1, scope="conv_2")
            h = tf.nn.relu(batch_norm(h, is_training=is_training, scope="bn_2"))
            h = conv2d(h, n_out, 1, stride=1, scope="conv_3")
            h = batch_norm(h, is_training=is_training, scope="bn_3")

            if n_in != n_out:
                # projection shortcut to match channels / spatial size
                shortcut = conv2d(x, n_out, 1, stride=stride, scope="conv_4")
                shortcut = batch_norm(shortcut, is_training=is_training,
                                      scope="bn_4")
            else:
                shortcut = x  # identity shortcut
            return tf.nn.relu(shortcut + h)
139
+
140
if __name__ == "__main__":
    # smoke test: build the graph on a random ImageNet-sized batch
    images = tf.random_normal([32, 224, 224, 3])
    model = ResNet50(images)
    print(model.logits)
0 commit comments