@@ -98,19 +98,13 @@ def __init__(self,
     self._norm_epsilon = norm_epsilon
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
-    if use_sync_bn:
-      self._norm = helper.quantize_wrapped_layer(
-          tf.keras.layers.experimental.SyncBatchNormalization,
-          configs.NoOpQuantizeConfig())
-      self._norm_with_quantize = helper.quantize_wrapped_layer(
-          tf.keras.layers.experimental.SyncBatchNormalization,
-          configs.Default8BitOutputQuantizeConfig())
-    else:
-      self._norm = helper.quantize_wrapped_layer(
-          tf.keras.layers.BatchNormalization, configs.NoOpQuantizeConfig())
-      self._norm_with_quantize = helper.quantize_wrapped_layer(
-          tf.keras.layers.BatchNormalization,
-          configs.Default8BitOutputQuantizeConfig())
+
+    norm_layer = (
+        tf.keras.layers.experimental.SyncBatchNormalization
+        if use_sync_bn else tf.keras.layers.BatchNormalization)
+    self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
+    self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
+
     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
     else:
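Note: the `helper.BatchNormalizationQuantized` / `helper.BatchNormalizationNoQuantized` factories used above are defined in `helper.py` and don't appear in this diff. Judging from the inline calls they replace, they presumably just fix the quantize config and forward the chosen norm class to `quantize_wrapped_layer`, along these lines (a sketch written as if inside `helper.py`, where `quantize_wrapped_layer` and `configs` are already in scope):

```python
# Hypothetical reconstruction, inferred from the removed inline calls; the
# real definitions live in helper.py, not in this diff.
def BatchNormalizationQuantized(norm_layer):
  # Wraps the norm class and attaches a default 8-bit output quantizer.
  return quantize_wrapped_layer(
      norm_layer, configs.Default8BitOutputQuantizeConfig())


def BatchNormalizationNoQuantized(norm_layer):
  # Wraps the norm class for QAT bookkeeping but adds no output quantizer;
  # used when a following quantized activation handles it instead.
  return quantize_wrapped_layer(norm_layer, configs.NoOpQuantizeConfig())
```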
@@ -119,15 +113,11 @@ def __init__(self,
 
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
-    conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
-                                              False))
     if self._use_projection:
       if self._resnetd_shortcut:
         self._shortcut0 = tf.keras.layers.AveragePooling2D(
             pool_size=2, strides=self._strides, padding='same')
-        self._shortcut1 = conv2d_quantized(
+        self._shortcut1 = helper.Conv2DQuantized(
            filters=self._filters * 4,
            kernel_size=1,
            strides=1,
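Same pattern for convolutions: `helper.Conv2DQuantized` replaces the `conv2d_quantized` alias that each `build()` used to reconstruct locally. Given the removed definition, it is presumably a module-level wrapped class (same assumptions as the sketch above):

```python
# Hypothetical module-level equivalent of the removed local alias.
Conv2DQuantized = quantize_wrapped_layer(
    tf.keras.layers.Conv2D,
    configs.Default8BitConvQuantizeConfig(
        ['kernel'], ['activation'], False))  # False: don't quantize output
```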
@@ -137,7 +127,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
            bias_regularizer=self._bias_regularizer,
            activation=helper.NoOpActivation())
       else:
-        self._shortcut = conv2d_quantized(
+        self._shortcut = helper.Conv2DQuantized(
            filters=self._filters * 4,
            kernel_size=1,
            strides=self._strides,
@@ -153,7 +143,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
         epsilon=self._norm_epsilon,
         trainable=self._bn_trainable)
 
-    self._conv1 = conv2d_quantized(
+    self._conv1 = helper.Conv2DQuantized(
         filters=self._filters,
         kernel_size=1,
         strides=1,
@@ -171,7 +161,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
         tf_utils.get_activation(self._activation, use_keras_layer=True),
         configs.Default8BitActivationQuantizeConfig())
 
-    self._conv2 = conv2d_quantized(
+    self._conv2 = helper.Conv2DQuantized(
         filters=self._filters,
         kernel_size=3,
         strides=self._strides,
@@ -191,7 +181,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
         tf_utils.get_activation(self._activation, use_keras_layer=True),
         configs.Default8BitActivationQuantizeConfig())
 
-    self._conv3 = conv2d_quantized(
+    self._conv3 = helper.Conv2DQuantized(
         filters=self._filters * 4,
         kernel_size=1,
         strides=1,
@@ -359,10 +349,8 @@ def __init__(
     norm_layer = (
         tf.keras.layers.experimental.SyncBatchNormalization
         if use_sync_bn else tf.keras.layers.BatchNormalization)
-    self._norm_with_quantize = helper.quantize_wrapped_layer(
-        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    self._norm = helper.quantize_wrapped_layer(norm_layer,
-                                               configs.NoOpQuantizeConfig())
+    self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
+    self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
 
     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
@@ -389,20 +377,15 @@ def get_config(self) -> Dict[str, Any]:
     base_config = super(Conv2DBNBlockQuantized, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
-  def _norm_by_activation(self, activation):
-    if activation in ['relu', 'relu6']:
-      return self._norm
-    return self._norm_with_quantize
-
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
     if self._use_explicit_padding and self._kernel_size > 1:
       padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
       self._pad = tf.keras.layers.ZeroPadding2D(padding_size)
-    conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
-                                              not self._use_normalization))
+    conv2d_quantized = (
+        helper.Conv2DQuantized
+        if self._use_normalization else helper.Conv2DOutputQuantized)
+
     self._conv0 = conv2d_quantized(
         filters=self._filters,
         kernel_size=self._kernel_size,
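Two helper moves in the hunk above. First, the per-class `_norm_by_activation` method becomes a shared `helper.norm_by_activation(activation, norm_with_quantize, norm)`; from the deleted body it presumably reads:

```python
# Hypothetical reconstruction of helper.norm_by_activation, inferred from
# the removed _norm_by_activation methods.
def norm_by_activation(activation, norm_quantized, norm_no_quantized):
  # relu/relu6 are followed by their own quantized activation wrapper, so
  # the preceding batch norm can skip output quantization; any other
  # activation keeps the quantized variant.
  if activation in ['relu', 'relu6']:
    return norm_no_quantized
  return norm_quantized
```

Second, the `not self._use_normalization` flag that used to be threaded into `Default8BitConvQuantizeConfig` is now expressed by picking between two prebuilt classes; `helper.Conv2DOutputQuantized` is presumably the same `Conv2D` wrapping with output quantization enabled (`True` as the third config argument).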
@@ -546,10 +530,8 @@ def __init__(self,
     norm_layer = (
         tf.keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf.keras.layers.BatchNormalization)
-    self._norm_with_quantize = helper.quantize_wrapped_layer(
-        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    self._norm = helper.quantize_wrapped_layer(norm_layer,
-                                               configs.NoOpQuantizeConfig())
+    self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
+    self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
 
     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
@@ -562,21 +544,8 @@ def __init__(self,
     else:
       self._depthsize_regularizer = None
 
-  def _norm_by_activation(self, activation):
-    if activation in ['relu', 'relu6']:
-      return self._norm
-    return self._norm_with_quantize
-
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
-    conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
-                                              False))
-    depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.DepthwiseConv2D,
-        configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
-                                              ['activation'], False))
     expand_filters = self._in_filters
     if self._expand_ratio > 1:
       # First 1x1 conv for channel expansion.
@@ -586,7 +555,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
       expand_kernel = 1 if self._use_depthwise else self._kernel_size
       expand_stride = 1 if self._use_depthwise else self._strides
 
-      self._conv0 = conv2d_quantized(
+      self._conv0 = helper.Conv2DQuantized(
          filters=expand_filters,
          kernel_size=expand_kernel,
          strides=expand_stride,
@@ -596,17 +565,18 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=helper.NoOpActivation())
-      self._norm0 = self._norm_by_activation(self._activation)(
-          axis=self._bn_axis,
-          momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+      self._norm0 = helper.norm_by_activation(self._activation,
+                                              self._norm_with_quantize,
+                                              self._norm)(
+                                                  axis=self._bn_axis,
+                                                  momentum=self._norm_momentum,
+                                                  epsilon=self._norm_epsilon)
       self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
           tf_utils.get_activation(self._activation, use_keras_layer=True),
           configs.Default8BitActivationQuantizeConfig())
-
     if self._use_depthwise:
       # Depthwise conv.
-      self._conv1 = depthwise_conv2d_quantized(
+      self._conv1 = helper.DepthwiseConv2DQuantized(
          kernel_size=(self._kernel_size, self._kernel_size),
          strides=self._strides,
          padding='same',
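`helper.DepthwiseConv2DQuantized` likewise replaces the removed `depthwise_conv2d_quantized` alias; from the deleted definition it is presumably (same assumptions as the earlier sketches):

```python
# Hypothetical module-level equivalent of the removed depthwise alias.
DepthwiseConv2DQuantized = quantize_wrapped_layer(
    tf.keras.layers.DepthwiseConv2D,
    configs.Default8BitConvQuantizeConfig(
        ['depthwise_kernel'], ['activation'], False))
```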
@@ -617,10 +587,12 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
          depthwise_regularizer=self._depthsize_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=helper.NoOpActivation())
-      self._norm1 = self._norm_by_activation(self._depthwise_activation)(
-          axis=self._bn_axis,
-          momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+      self._norm1 = helper.norm_by_activation(self._depthwise_activation,
+                                              self._norm_with_quantize,
+                                              self._norm)(
+                                                  axis=self._bn_axis,
+                                                  momentum=self._norm_momentum,
+                                                  epsilon=self._norm_epsilon)
       self._depthwise_activation_layer = (
           tfmot.quantization.keras.QuantizeWrapperV2(
               tf_utils.get_activation(self._depthwise_activation,
@@ -648,7 +620,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
      self._squeeze_excitation = None
 
     # Last 1x1 conv.
-    self._conv2 = conv2d_quantized(
+    self._conv2 = helper.Conv2DQuantized(
        filters=self._out_filters,
        kernel_size=1,
        strides=1,