@@ -1,10 +1,10 @@
-# /usr/bin/python
+#! /usr/bin/python
 # -*- coding: utf-8 -*-

+import numpy as np
 import tensorflow as tf
+import tensorlayer as tl
 from tensorflow.python.training import moving_averages
-
 from tensorlayer import logging
-from tensorlayer.decorators import deprecated_alias
 from tensorlayer.layers.core import Layer
 from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
@@ -22,8 +22,6 @@ class QuanConv2dWithBN(Layer):

    Parameters
    ----------
-    prev_layer : :class:`Layer`
-        Previous layer.
    n_filter : int
        The number of filters.
    filter_size : tuple of int
@@ -51,49 +49,33 @@ class QuanConv2dWithBN(Layer):
        The number of bits for this layer's weights.
    bitA : int
        The number of bits for the output of the previous layer.
-    decay : float
-        A decay factor for `ExponentialMovingAverage`.
-        Suggest to use a large value for large dataset.
-    epsilon : float
-        Eplison.
-    is_train : boolean
-        Is being used for training or inference.
-    beta_init : initializer or None
-        The initializer for initializing beta, if None, skip beta.
-        Usually you should not skip beta unless you know what happened.
-    gamma_init : initializer or None
-        The initializer for initializing gamma, if None, skip gamma.
    use_gemm : boolean
        If True, use gemm instead of ``tf.matmul`` for inferencing. (TODO).
    W_init : initializer
        The initializer for the weight matrix.
    W_init_args : dictionary
        The arguments for the weight matrix initializer.
-    use_cudnn_on_gpu : bool
-        Default is False.
-    data_format : str
-        "NHWC" or "NCHW", default is "NHWC".
+    data_format : str
+        "channels_last" (NHWC, default) or "channels_first" (NCHW).
+    dilation_rate : tuple of int
+        Specifying the dilation rate to use for dilated convolution.
+    in_channels : int
+        The number of input channels. If None, it is inferred from the input shape on the first call.
    name : str
        A unique layer name.

    Examples
    ---------
-    >>> import tensorflow as tf
    >>> import tensorlayer as tl
-    >>> x = tf.placeholder(tf.float32, [None, 256, 256, 3])
-    >>> net = tl.layers.InputLayer(x, name='input')
-    >>> net = tl.layers.QuanConv2dWithBN(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', is_train=is_train, bitW=bitW, bitA=bitA, name='qcnnbn1')
-    >>> net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1')
-    ...
-    >>> net = tl.layers.QuanConv2dWithBN(net, 64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, is_train=is_train, bitW=bitW, bitA=bitA, name='qcnnbn2')
-    >>> net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2')
-    ...
+    >>> net = tl.layers.Input([50, 256, 256, 3])
+    >>> layer = tl.layers.QuanConv2dWithBN(n_filter=64, filter_size=(5, 5), strides=(1, 1), padding='SAME', name='qcnnbn1')
+    >>> print(layer)
+    >>> net = tl.layers.QuanConv2dWithBN(n_filter=64, filter_size=(5, 5), strides=(1, 1), padding='SAME', name='qcnnbn2')(net)
+    >>> print(net)

    """
-    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
        self,
-        prev_layer,
        n_filter=32,
        filter_size=(3, 3),
        strides=(1, 1),
@@ -102,125 +84,150 @@ def __init__(
        decay=0.9,
        epsilon=1e-5,
        is_train=False,
-        gamma_init=tf.compat.v1.initializers.ones,
-        beta_init=tf.compat.v1.initializers.zeros,
+        gamma_init=tl.initializers.truncated_normal(stddev=0.02),
+        beta_init=tl.initializers.truncated_normal(stddev=0.02),
        bitW=8,
        bitA=8,
        use_gemm=False,
-        W_init=tf.compat.v1.initializers.truncated_normal(stddev=0.02),
+        W_init=tl.initializers.truncated_normal(stddev=0.02),
        W_init_args=None,
-        use_cudnn_on_gpu=None,
-        data_format=None,
+        data_format="channels_last",
+        dilation_rate=(1, 1),
+        in_channels=None,
        name='quan_cnn2d_bn',
    ):
-        super(QuanConv2dWithBN, self).__init__(prev_layer=prev_layer, act=act, W_init_args=W_init_args, name=name)
-
+        super(QuanConv2dWithBN, self).__init__(act=act, name=name)
+        self.n_filter = n_filter
+        self.filter_size = filter_size
+        self.strides = strides
+        self.padding = padding
+        self.decay = decay
+        self.epsilon = epsilon
+        self.is_train = is_train
+        self.gamma_init = gamma_init
+        self.beta_init = beta_init
+        self.bitW = bitW
+        self.bitA = bitA
+        self.use_gemm = use_gemm
+        self.W_init = W_init
+        self.W_init_args = W_init_args
+        self.data_format = data_format
+        self.dilation_rate = dilation_rate
+        self.in_channels = in_channels
        logging.info(
            "QuanConv2dWithBN %s: n_filter: %d filter_size: %s strides: %s pad: %s act: %s" % (
                self.name, n_filter, filter_size, str(strides), padding,
                self.act.__name__ if self.act is not None else 'No Activation'
            )
        )

-        x = self.inputs
-        self.inputs = quantize_active_overflow(self.inputs, bitA)  # Do not remove
+        if self.in_channels:
+            self.build(None)
+            self._built = True

        if use_gemm:
            raise Exception("TODO. The current version uses tf.matmul for inferencing.")

        if len(strides) != 2:
            raise ValueError("len(strides) should be 2.")

-        try:
-            pre_channel = int(prev_layer.outputs.get_shape()[-1])
-        except Exception:  # if pre_channel is ?, it happens when using Spatial Transformer Net
-            pre_channel = 1
-            logging.warning("[warnings] unknow input channels, set to 1")
-
-        shape = (filter_size[0], filter_size[1], pre_channel, n_filter)
-        strides = (1, strides[0], strides[1], 1)
-
-        with tf.compat.v1.variable_scope(name):
-            W = tf.compat.v1.get_variable(
-                name='W_conv2d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **self.W_init_args
-            )
-
-            conv = tf.nn.conv2d(
-                x, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format
-            )
-
-            para_bn_shape = conv.get_shape()[-1:]
-
-            if gamma_init:
-                scale_para = tf.compat.v1.get_variable(
-                    name='scale_para', shape=para_bn_shape, initializer=gamma_init, dtype=LayersConfig.tf_dtype,
-                    trainable=is_train
-                )
-            else:
-                scale_para = None
-
-            if beta_init:
-                offset_para = tf.compat.v1.get_variable(
-                    name='offset_para', shape=para_bn_shape, initializer=beta_init, dtype=LayersConfig.tf_dtype,
-                    trainable=is_train
-                )
-            else:
-                offset_para = None
-
-            moving_mean = tf.compat.v1.get_variable(
-                'moving_mean', para_bn_shape, initializer=tf.compat.v1.initializers.constant(1.),
-                dtype=LayersConfig.tf_dtype, trainable=False
-            )
-
-            moving_variance = tf.compat.v1.get_variable(
-                'moving_variance',
-                para_bn_shape,
-                initializer=tf.compat.v1.initializers.constant(1.),
-                dtype=LayersConfig.tf_dtype,
-                trainable=False,
-            )
-
-            mean, variance = tf.nn.moments(x=conv, axes=list(range(len(conv.get_shape()) - 1)))
-
-            update_moving_mean = moving_averages.assign_moving_average(
-                moving_mean, mean, decay, zero_debias=False
-            )  # if zero_debias=True, has bias
-
-            update_moving_variance = moving_averages.assign_moving_average(
-                moving_variance, variance, decay, zero_debias=False
-            )  # if zero_debias=True, has bias
-
-            def mean_var_with_update():
-                with tf.control_dependencies([update_moving_mean, update_moving_variance]):
-                    return tf.identity(mean), tf.identity(variance)
-
-            if is_train:
-                mean, var = mean_var_with_update()
-            else:
-                mean, var = moving_mean, moving_variance
-
-            w_fold = _w_fold(W, scale_para, var, epsilon)
-            bias_fold = _bias_fold(offset_para, scale_para, mean, var, epsilon)
-
-            W = quantize_weight_overflow(w_fold, bitW)
-
-            conv_fold = tf.nn.conv2d(
-                self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu,
-                data_format=data_format
-            )
-
-            self.outputs = tf.nn.bias_add(conv_fold, bias_fold, name='bn_bias_add')
-
-            self.outputs = self._apply_activation(self.outputs)
-
-            self._add_layers(self.outputs)
-
-            self._add_params([W, scale_para, offset_para, moving_mean, moving_variance])
+    def __repr__(self):
+        actstr = self.act.__name__ if self.act is not None else 'No Activation'
+        s = (
+            '{classname}(in_channels={in_channels}, out_channels={n_filter}, kernel_size={filter_size}'
+            ', strides={strides}, padding={padding}, ' + actstr
+        )
+        if self.dilation_rate != (1,) * len(self.dilation_rate):
+            s += ', dilation={dilation_rate}'
+        if self.name is not None:
+            s += ', name=\'{name}\''
+        s += ')'
+        return s.format(classname=self.__class__.__name__, **self.__dict__)
+
+    def build(self, inputs_shape):
+        if self.data_format == 'channels_last':
+            self.data_format = 'NHWC'
+            if self.in_channels is None:
+                self.in_channels = inputs_shape[-1]
+            self._strides = [1, self.strides[0], self.strides[1], 1]
+            self._dilation_rate = [1, self.dilation_rate[0], self.dilation_rate[1], 1]
+        elif self.data_format == 'channels_first':
+            self.data_format = 'NCHW'
+            if self.in_channels is None:
+                self.in_channels = inputs_shape[1]
+            self._strides = [1, 1, self.strides[0], self.strides[1]]
+            self._dilation_rate = [1, 1, self.dilation_rate[0], self.dilation_rate[1]]
+        else:
+            raise Exception("data_format should be either channels_last or channels_first")
+
+        self.filter_shape = (self.filter_size[0], self.filter_size[1], self.in_channels, self.n_filter)
+        self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
+
+        para_bn_shape = (self.n_filter, )
+        if self.gamma_init:
+            self.scale_para = self._get_weights(
+                "scale_para", shape=para_bn_shape, init=self.gamma_init, trainable=self.is_train
+            )
+        else:
+            self.scale_para = None
+
+        if self.beta_init:
+            self.offset_para = self._get_weights(
+                "offset_para", shape=para_bn_shape, init=self.beta_init, trainable=self.is_train
+            )
+        else:
+            self.offset_para = None
+
+        self.moving_mean = self._get_weights(
+            "moving_mean", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
+        )
+        self.moving_variance = self._get_weights(
+            "moving_variance", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
+        )
+
+    def forward(self, inputs):
+        x = inputs
+        inputs = quantize_active_overflow(inputs, self.bitA)  # Do not remove
+        # The float-path convolution below is used only to collect batch statistics;
+        # the quantized activations feed the folded convolution further down.
+        outputs = tf.nn.conv2d(
+            input=x, filters=self.W, strides=self._strides, padding=self.padding, data_format=self.data_format,
+            dilations=self._dilation_rate, name=self.name
+        )
+
+        mean, variance = tf.nn.moments(outputs, axes=list(range(len(outputs.get_shape()) - 1)))
+
+        update_moving_mean = moving_averages.assign_moving_average(
+            self.moving_mean, mean, self.decay, zero_debias=False
+        )  # if zero_debias=True, has bias
+        update_moving_variance = moving_averages.assign_moving_average(
+            self.moving_variance, variance, self.decay, zero_debias=False
+        )  # if zero_debias=True, has bias
+
+        if self.is_train:
+            mean, var = self.mean_var_with_update(update_moving_mean, update_moving_variance, mean, variance)
+        else:
+            mean, var = self.moving_mean, self.moving_variance
+
+        w_fold = self._w_fold(self.W, self.scale_para, var, self.epsilon)
+
+        W_ = quantize_weight_overflow(w_fold, self.bitW)
+
+        conv_fold = tf.nn.conv2d(
+            inputs, W_, strides=self._strides, padding=self.padding, data_format=self.data_format,
+            dilations=self._dilation_rate
+        )
+
+        if self.beta_init:
+            bias_fold = self._bias_fold(self.offset_para, self.scale_para, mean, var, self.epsilon)
+            conv_fold = tf.nn.bias_add(conv_fold, bias_fold, name='bn_bias_add')
+
+        if self.act:
+            conv_fold = self.act(conv_fold)
+
+        return conv_fold
+
+    def mean_var_with_update(self, update_moving_mean, update_moving_variance, mean, variance):
+        with tf.control_dependencies([update_moving_mean, update_moving_variance]):
+            return tf.identity(mean), tf.identity(variance)

-def _w_fold(w, gama, var, epsilon):
-    return tf.compat.v1.div(tf.multiply(gama, w), tf.sqrt(var + epsilon))
+    def _w_fold(self, w, gama, var, epsilon):
+        return tf.compat.v1.div(tf.multiply(gama, w), tf.sqrt(var + epsilon))

-def _bias_fold(beta, gama, mean, var, epsilon):
-    return tf.subtract(beta, tf.compat.v1.div(tf.multiply(gama, mean), tf.sqrt(var + epsilon)))
+    def _bias_fold(self, beta, gama, mean, var, epsilon):
+        return tf.subtract(beta, tf.compat.v1.div(tf.multiply(gama, mean), tf.sqrt(var + epsilon)))
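The refactored forward pass folds the batch-norm parameters into the convolution kernel before weight quantization: _w_fold computes gamma * W / sqrt(var + epsilon) and _bias_fold computes beta - gamma * mean / sqrt(var + epsilon), so a single convolution with the folded kernel plus a bias add reproduces convolution followed by batch normalization. A minimal NumPy sketch of that identity (standalone and illustrative only, not part of the commit; the "convolution" is a matmul and all shapes and values are made up):

    import numpy as np

    rng = np.random.default_rng(0)
    eps = 1e-5

    x = rng.standard_normal((4, 8))   # stand-in for layer inputs
    w = rng.standard_normal((8, 3))   # stand-in for a conv kernel (used as a matmul)
    gamma = rng.standard_normal(3)    # BN scale  (scale_para)
    beta = rng.standard_normal(3)     # BN offset (offset_para)

    y = x @ w                         # the "convolution"
    mu, var = y.mean(axis=0), y.var(axis=0)

    # Convolution followed by plain batch normalization.
    bn = gamma * (y - mu) / np.sqrt(var + eps) + beta

    # Folded form, mirroring _w_fold and _bias_fold above.
    w_fold = gamma * w / np.sqrt(var + eps)
    bias_fold = beta - gamma * mu / np.sqrt(var + eps)
    folded = x @ w_fold + bias_fold

    print("BN folding matches:", np.allclose(bn, folded))  # True

In the layer itself, quantize_weight_overflow is applied to w_fold between the folding and the second convolution, so the identity holds exactly only before quantization.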
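Because the constructor calls build(None) only when in_channels is given, the layer supports both eager and deferred weight creation. A usage sketch under that assumption (TensorLayer 2.x functional API as in the docstring example; names and shapes are illustrative):

    import tensorlayer as tl

    # Deferred build: in_channels is inferred from the input on the first call.
    net = tl.layers.Input([8, 32, 32, 3])
    layer = tl.layers.QuanConv2dWithBN(n_filter=16, filter_size=(3, 3), name='qcnnbn_deferred')
    out = layer(net)  # build() runs here with the actual input shape

    # Eager build: passing in_channels creates the filters and BN parameters immediately.
    eager = tl.layers.QuanConv2dWithBN(n_filter=16, filter_size=(3, 3), in_channels=3, name='qcnnbn_eager')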