@@ -98,19 +98,13 @@ def __init__(self,
9898 self ._norm_epsilon = norm_epsilon
9999 self ._kernel_regularizer = kernel_regularizer
100100 self ._bias_regularizer = bias_regularizer
101- if use_sync_bn :
102- self ._norm = helper .quantize_wrapped_layer (
103- tf .keras .layers .experimental .SyncBatchNormalization ,
104- configs .NoOpQuantizeConfig ())
105- self ._norm_with_quantize = helper .quantize_wrapped_layer (
106- tf .keras .layers .experimental .SyncBatchNormalization ,
107- configs .Default8BitOutputQuantizeConfig ())
108- else :
109- self ._norm = helper .quantize_wrapped_layer (
110- tf .keras .layers .BatchNormalization , configs .NoOpQuantizeConfig ())
111- self ._norm_with_quantize = helper .quantize_wrapped_layer (
112- tf .keras .layers .BatchNormalization ,
113- configs .Default8BitOutputQuantizeConfig ())
101+
102+ norm_layer = (
103+ tf .keras .layers .experimental .SyncBatchNormalization
104+ if use_sync_bn else tf .keras .layers .BatchNormalization )
105+ self ._norm_with_quantize = helper .BatchNormalizationQuantized (norm_layer )
106+ self ._norm = helper .BatchNormalizationNoQuantized (norm_layer )
107+
114108 if tf .keras .backend .image_data_format () == 'channels_last' :
115109 self ._bn_axis = - 1
116110 else :
@@ -119,15 +113,11 @@ def __init__(self,
119113
120114 def build (self , input_shape : Optional [Union [Sequence [int ], tf .Tensor ]]):
121115 """Build variables and child layers to prepare for calling."""
122- conv2d_quantized = helper .quantize_wrapped_layer (
123- tf .keras .layers .Conv2D ,
124- configs .Default8BitConvQuantizeConfig (['kernel' ], ['activation' ],
125- False ))
126116 if self ._use_projection :
127117 if self ._resnetd_shortcut :
128118 self ._shortcut0 = tf .keras .layers .AveragePooling2D (
129119 pool_size = 2 , strides = self ._strides , padding = 'same' )
130- self ._shortcut1 = conv2d_quantized (
120+ self ._shortcut1 = helper . Conv2DQuantized (
131121 filters = self ._filters * 4 ,
132122 kernel_size = 1 ,
133123 strides = 1 ,
@@ -137,7 +127,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
137127 bias_regularizer = self ._bias_regularizer ,
138128 activation = helper .NoOpActivation ())
139129 else :
140- self ._shortcut = conv2d_quantized (
130+ self ._shortcut = helper . Conv2DQuantized (
141131 filters = self ._filters * 4 ,
142132 kernel_size = 1 ,
143133 strides = self ._strides ,
@@ -153,7 +143,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
153143 epsilon = self ._norm_epsilon ,
154144 trainable = self ._bn_trainable )
155145
156- self ._conv1 = conv2d_quantized (
146+ self ._conv1 = helper . Conv2DQuantized (
157147 filters = self ._filters ,
158148 kernel_size = 1 ,
159149 strides = 1 ,
@@ -171,7 +161,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
171161 tf_utils .get_activation (self ._activation , use_keras_layer = True ),
172162 configs .Default8BitActivationQuantizeConfig ())
173163
174- self ._conv2 = conv2d_quantized (
164+ self ._conv2 = helper . Conv2DQuantized (
175165 filters = self ._filters ,
176166 kernel_size = 3 ,
177167 strides = self ._strides ,
@@ -191,7 +181,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
191181 tf_utils .get_activation (self ._activation , use_keras_layer = True ),
192182 configs .Default8BitActivationQuantizeConfig ())
193183
194- self ._conv3 = conv2d_quantized (
184+ self ._conv3 = helper . Conv2DQuantized (
195185 filters = self ._filters * 4 ,
196186 kernel_size = 1 ,
197187 strides = 1 ,
@@ -359,10 +349,8 @@ def __init__(
359349 norm_layer = (
360350 tf .keras .layers .experimental .SyncBatchNormalization
361351 if use_sync_bn else tf .keras .layers .BatchNormalization )
362- self ._norm_with_quantize = helper .quantize_wrapped_layer (
363- norm_layer , configs .Default8BitOutputQuantizeConfig ())
364- self ._norm = helper .quantize_wrapped_layer (norm_layer ,
365- configs .NoOpQuantizeConfig ())
352+ self ._norm_with_quantize = helper .BatchNormalizationQuantized (norm_layer )
353+ self ._norm = helper .BatchNormalizationNoQuantized (norm_layer )
366354
367355 if tf .keras .backend .image_data_format () == 'channels_last' :
368356 self ._bn_axis = - 1
@@ -389,20 +377,15 @@ def get_config(self) -> Dict[str, Any]:
389377 base_config = super (Conv2DBNBlockQuantized , self ).get_config ()
390378 return dict (list (base_config .items ()) + list (config .items ()))
391379
392- def _norm_by_activation (self , activation ):
393- if activation in ['relu' , 'relu6' ]:
394- return self ._norm
395- return self ._norm_with_quantize
396-
397380 def build (self , input_shape : Optional [Union [Sequence [int ], tf .Tensor ]]):
398381 """Build variables and child layers to prepare for calling."""
399382 if self ._use_explicit_padding and self ._kernel_size > 1 :
400383 padding_size = nn_layers .get_padding_for_kernel_size (self ._kernel_size )
401384 self ._pad = tf .keras .layers .ZeroPadding2D (padding_size )
402- conv2d_quantized = helper . quantize_wrapped_layer (
403- tf . keras . layers . Conv2D ,
404- configs . Default8BitConvQuantizeConfig ([ 'kernel' ], [ 'activation' ],
405- not self . _use_normalization ))
385+ conv2d_quantized = (
386+ helper . Conv2DQuantized
387+ if self . _use_normalization else helper . Conv2DOutputQuantized )
388+
406389 self ._conv0 = conv2d_quantized (
407390 filters = self ._filters ,
408391 kernel_size = self ._kernel_size ,
@@ -414,14 +397,15 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
414397 bias_regularizer = self ._bias_regularizer ,
415398 activation = helper .NoOpActivation ())
416399 if self ._use_normalization :
417- self ._norm0 = self ._norm_by_activation (self ._activation )(
418- axis = self ._bn_axis ,
419- momentum = self ._norm_momentum ,
420- epsilon = self ._norm_epsilon )
400+ self ._norm0 = helper .norm_by_activation (self ._activation ,
401+ self ._norm_with_quantize ,
402+ self ._norm )(
403+ axis = self ._bn_axis ,
404+ momentum = self ._norm_momentum ,
405+ epsilon = self ._norm_epsilon )
421406 self ._activation_layer = tfmot .quantization .keras .QuantizeWrapperV2 (
422407 tf_utils .get_activation (self ._activation , use_keras_layer = True ),
423408 configs .Default8BitActivationQuantizeConfig ())
424-
425409 super (Conv2DBNBlockQuantized , self ).build (input_shape )
426410
427411 def call (
@@ -546,10 +530,8 @@ def __init__(self,
546530 norm_layer = (
547531 tf .keras .layers .experimental .SyncBatchNormalization
548532 if use_sync_bn else tf .keras .layers .BatchNormalization )
549- self ._norm_with_quantize = helper .quantize_wrapped_layer (
550- norm_layer , configs .Default8BitOutputQuantizeConfig ())
551- self ._norm = helper .quantize_wrapped_layer (norm_layer ,
552- configs .NoOpQuantizeConfig ())
533+ self ._norm_with_quantize = helper .BatchNormalizationQuantized (norm_layer )
534+ self ._norm = helper .BatchNormalizationNoQuantized (norm_layer )
553535
554536 if tf .keras .backend .image_data_format () == 'channels_last' :
555537 self ._bn_axis = - 1
@@ -562,21 +544,8 @@ def __init__(self,
562544 else :
563545 self ._depthsize_regularizer = None
564546
565- def _norm_by_activation (self , activation ):
566- if activation in ['relu' , 'relu6' ]:
567- return self ._norm
568- return self ._norm_with_quantize
569-
570547 def build (self , input_shape : Optional [Union [Sequence [int ], tf .Tensor ]]):
571548 """Build variables and child layers to prepare for calling."""
572- conv2d_quantized = helper .quantize_wrapped_layer (
573- tf .keras .layers .Conv2D ,
574- configs .Default8BitConvQuantizeConfig (['kernel' ], ['activation' ],
575- False ))
576- depthwise_conv2d_quantized = helper .quantize_wrapped_layer (
577- tf .keras .layers .DepthwiseConv2D ,
578- configs .Default8BitConvQuantizeConfig (['depthwise_kernel' ],
579- ['activation' ], False ))
580549 expand_filters = self ._in_filters
581550 if self ._expand_ratio > 1 :
582551 # First 1x1 conv for channel expansion.
@@ -586,7 +555,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
586555 expand_kernel = 1 if self ._use_depthwise else self ._kernel_size
587556 expand_stride = 1 if self ._use_depthwise else self ._strides
588557
589- self ._conv0 = conv2d_quantized (
558+ self ._conv0 = helper . Conv2DQuantized (
590559 filters = expand_filters ,
591560 kernel_size = expand_kernel ,
592561 strides = expand_stride ,
@@ -596,17 +565,18 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
596565 kernel_regularizer = self ._kernel_regularizer ,
597566 bias_regularizer = self ._bias_regularizer ,
598567 activation = helper .NoOpActivation ())
599- self ._norm0 = self ._norm_by_activation (self ._activation )(
600- axis = self ._bn_axis ,
601- momentum = self ._norm_momentum ,
602- epsilon = self ._norm_epsilon )
568+ self ._norm0 = helper .norm_by_activation (self ._activation ,
569+ self ._norm_with_quantize ,
570+ self ._norm )(
571+ axis = self ._bn_axis ,
572+ momentum = self ._norm_momentum ,
573+ epsilon = self ._norm_epsilon )
603574 self ._activation_layer = tfmot .quantization .keras .QuantizeWrapperV2 (
604575 tf_utils .get_activation (self ._activation , use_keras_layer = True ),
605576 configs .Default8BitActivationQuantizeConfig ())
606-
607577 if self ._use_depthwise :
608578 # Depthwise conv.
609- self ._conv1 = depthwise_conv2d_quantized (
579+ self ._conv1 = helper . DepthwiseConv2DQuantized (
610580 kernel_size = (self ._kernel_size , self ._kernel_size ),
611581 strides = self ._strides ,
612582 padding = 'same' ,
@@ -617,10 +587,12 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
617587 depthwise_regularizer = self ._depthsize_regularizer ,
618588 bias_regularizer = self ._bias_regularizer ,
619589 activation = helper .NoOpActivation ())
620- self ._norm1 = self ._norm_by_activation (self ._depthwise_activation )(
621- axis = self ._bn_axis ,
622- momentum = self ._norm_momentum ,
623- epsilon = self ._norm_epsilon )
590+ self ._norm1 = helper .norm_by_activation (self ._depthwise_activation ,
591+ self ._norm_with_quantize ,
592+ self ._norm )(
593+ axis = self ._bn_axis ,
594+ momentum = self ._norm_momentum ,
595+ epsilon = self ._norm_epsilon )
624596 self ._depthwise_activation_layer = (
625597 tfmot .quantization .keras .QuantizeWrapperV2 (
626598 tf_utils .get_activation (self ._depthwise_activation ,
@@ -648,7 +620,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
648620 self ._squeeze_excitation = None
649621
650622 # Last 1x1 conv.
651- self ._conv2 = conv2d_quantized (
623+ self ._conv2 = helper . Conv2DQuantized (
652624 filters = self ._out_filters ,
653625 kernel_size = 1 ,
654626 strides = 1 ,
0 commit comments