
Commit 5036929

fyangf authored and tensorflower-gardener committed

Internal change.

PiperOrigin-RevId: 437372581

1 parent 0876884 commit 5036929

3 files changed: +114 -150 lines changed


official/projects/qat/vision/modeling/layers/nn_blocks.py (+41 -69)
@@ -98,19 +98,13 @@ def __init__(self,
     self._norm_epsilon = norm_epsilon
     self._kernel_regularizer = kernel_regularizer
     self._bias_regularizer = bias_regularizer
-    if use_sync_bn:
-      self._norm = helper.quantize_wrapped_layer(
-          tf.keras.layers.experimental.SyncBatchNormalization,
-          configs.NoOpQuantizeConfig())
-      self._norm_with_quantize = helper.quantize_wrapped_layer(
-          tf.keras.layers.experimental.SyncBatchNormalization,
-          configs.Default8BitOutputQuantizeConfig())
-    else:
-      self._norm = helper.quantize_wrapped_layer(
-          tf.keras.layers.BatchNormalization, configs.NoOpQuantizeConfig())
-      self._norm_with_quantize = helper.quantize_wrapped_layer(
-          tf.keras.layers.BatchNormalization,
-          configs.Default8BitOutputQuantizeConfig())
+
+    norm_layer = (
+        tf.keras.layers.experimental.SyncBatchNormalization
+        if use_sync_bn else tf.keras.layers.BatchNormalization)
+    self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
+    self._norm = helper.BatchNormalizationNoQuantized(norm_layer)
+
     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
     else:
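This commit's visible effect is to replace the repeated helper.quantize_wrapped_layer(...) boilerplate in nn_blocks.py with shared aliases from the QAT helper module. Judging from the call sites removed above, the two batch-norm aliases presumably reduce to the following (a reconstructed sketch, not the verbatim helper.py definitions):

def BatchNormalizationQuantized(norm_layer):
  # Wrap the given norm class so its output is fake-quantized to 8 bits.
  return quantize_wrapped_layer(
      norm_layer, configs.Default8BitOutputQuantizeConfig())

def BatchNormalizationNoQuantized(norm_layer):
  # Wrap the given norm class without adding an output quantizer.
  return quantize_wrapped_layer(
      norm_layer, configs.NoOpQuantizeConfig())

Each wrapper returns a class; the block instantiates it later with the usual axis/momentum/epsilon arguments.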
@@ -119,15 +113,11 @@ def __init__(self,

   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
-    conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
-                                              False))
     if self._use_projection:
       if self._resnetd_shortcut:
         self._shortcut0 = tf.keras.layers.AveragePooling2D(
             pool_size=2, strides=self._strides, padding='same')
-        self._shortcut1 = conv2d_quantized(
+        self._shortcut1 = helper.Conv2DQuantized(
             filters=self._filters * 4,
             kernel_size=1,
             strides=1,
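helper.Conv2DQuantized plays the same role for convolutions, standing in for the local conv2d_quantized binding deleted above. A sketch consistent with the removed call (assumed, not verbatim helper.py code):

Conv2DQuantized = quantize_wrapped_layer(
    tf.keras.layers.Conv2D,
    configs.Default8BitConvQuantizeConfig(
        ['kernel'], ['activation'], False))  # False: no output quantizer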
@@ -137,7 +127,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
             bias_regularizer=self._bias_regularizer,
             activation=helper.NoOpActivation())
       else:
-        self._shortcut = conv2d_quantized(
+        self._shortcut = helper.Conv2DQuantized(
             filters=self._filters * 4,
             kernel_size=1,
             strides=self._strides,
@@ -153,7 +143,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
           epsilon=self._norm_epsilon,
           trainable=self._bn_trainable)

-    self._conv1 = conv2d_quantized(
+    self._conv1 = helper.Conv2DQuantized(
         filters=self._filters,
         kernel_size=1,
         strides=1,
@@ -171,7 +161,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
         tf_utils.get_activation(self._activation, use_keras_layer=True),
         configs.Default8BitActivationQuantizeConfig())

-    self._conv2 = conv2d_quantized(
+    self._conv2 = helper.Conv2DQuantized(
         filters=self._filters,
         kernel_size=3,
         strides=self._strides,
@@ -191,7 +181,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
         tf_utils.get_activation(self._activation, use_keras_layer=True),
         configs.Default8BitActivationQuantizeConfig())

-    self._conv3 = conv2d_quantized(
+    self._conv3 = helper.Conv2DQuantized(
         filters=self._filters * 4,
         kernel_size=1,
         strides=1,
@@ -359,10 +349,8 @@ def __init__(
     norm_layer = (
         tf.keras.layers.experimental.SyncBatchNormalization
        if use_sync_bn else tf.keras.layers.BatchNormalization)
-    self._norm_with_quantize = helper.quantize_wrapped_layer(
-        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    self._norm = helper.quantize_wrapped_layer(norm_layer,
-                                               configs.NoOpQuantizeConfig())
+    self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
+    self._norm = helper.BatchNormalizationNoQuantized(norm_layer)

     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
@@ -389,20 +377,15 @@ def get_config(self) -> Dict[str, Any]:
     base_config = super(Conv2DBNBlockQuantized, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def _norm_by_activation(self, activation):
-    if activation in ['relu', 'relu6']:
-      return self._norm
-    return self._norm_with_quantize
-
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
     if self._use_explicit_padding and self._kernel_size > 1:
       padding_size = nn_layers.get_padding_for_kernel_size(self._kernel_size)
       self._pad = tf.keras.layers.ZeroPadding2D(padding_size)
-    conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
-                                              not self._use_normalization))
+    conv2d_quantized = (
+        helper.Conv2DQuantized
+        if self._use_normalization else helper.Conv2DOutputQuantized)
+
     self._conv0 = conv2d_quantized(
         filters=self._filters,
         kernel_size=self._kernel_size,
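This block quantizes the conv output only when no normalization layer follows, so it now selects between two prebuilt aliases instead of toggling the config flag inline. Matching the removed code, helper.Conv2DOutputQuantized is presumably the flag-flipped twin of Conv2DQuantized (sketch, assumed):

Conv2DOutputQuantized = quantize_wrapped_layer(
    tf.keras.layers.Conv2D,
    configs.Default8BitConvQuantizeConfig(
        ['kernel'], ['activation'], True))  # True: quantize the conv output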
@@ -414,14 +397,15 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
         bias_regularizer=self._bias_regularizer,
         activation=helper.NoOpActivation())
     if self._use_normalization:
-      self._norm0 = self._norm_by_activation(self._activation)(
-          axis=self._bn_axis,
-          momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+      self._norm0 = helper.norm_by_activation(self._activation,
+                                              self._norm_with_quantize,
+                                              self._norm)(
+                                                  axis=self._bn_axis,
+                                                  momentum=self._norm_momentum,
+                                                  epsilon=self._norm_epsilon)
     self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
         tf_utils.get_activation(self._activation, use_keras_layer=True),
         configs.Default8BitActivationQuantizeConfig())
-
     super(Conv2DBNBlockQuantized, self).build(input_shape)

   def call(
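The per-class _norm_by_activation methods deleted in this commit are consolidated into helper.norm_by_activation. Matching the removed logic and the argument order at the new call sites, it presumably behaves like (sketch):

def norm_by_activation(activation, norm_quantize, norm_no_quantize):
  # relu/relu6 can ride on the preceding conv's quantized output, so the
  # norm that follows them skips the extra output quantizer.
  if activation in ['relu', 'relu6']:
    return norm_no_quantize
  return norm_quantize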
@@ -546,10 +530,8 @@ def __init__(self,
     norm_layer = (
         tf.keras.layers.experimental.SyncBatchNormalization
         if use_sync_bn else tf.keras.layers.BatchNormalization)
-    self._norm_with_quantize = helper.quantize_wrapped_layer(
-        norm_layer, configs.Default8BitOutputQuantizeConfig())
-    self._norm = helper.quantize_wrapped_layer(norm_layer,
-                                               configs.NoOpQuantizeConfig())
+    self._norm_with_quantize = helper.BatchNormalizationQuantized(norm_layer)
+    self._norm = helper.BatchNormalizationNoQuantized(norm_layer)

     if tf.keras.backend.image_data_format() == 'channels_last':
       self._bn_axis = -1
@@ -562,21 +544,8 @@ def __init__(self,
     else:
       self._depthsize_regularizer = None

-  def _norm_by_activation(self, activation):
-    if activation in ['relu', 'relu6']:
-      return self._norm
-    return self._norm_with_quantize
-
   def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
     """Build variables and child layers to prepare for calling."""
-    conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.Conv2D,
-        configs.Default8BitConvQuantizeConfig(['kernel'], ['activation'],
-                                              False))
-    depthwise_conv2d_quantized = helper.quantize_wrapped_layer(
-        tf.keras.layers.DepthwiseConv2D,
-        configs.Default8BitConvQuantizeConfig(['depthwise_kernel'],
-                                              ['activation'], False))
     expand_filters = self._in_filters
     if self._expand_ratio > 1:
       # First 1x1 conv for channel expansion.
@@ -586,7 +555,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
       expand_kernel = 1 if self._use_depthwise else self._kernel_size
       expand_stride = 1 if self._use_depthwise else self._strides

-      self._conv0 = conv2d_quantized(
+      self._conv0 = helper.Conv2DQuantized(
           filters=expand_filters,
           kernel_size=expand_kernel,
           strides=expand_stride,
@@ -596,17 +565,18 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
           kernel_regularizer=self._kernel_regularizer,
           bias_regularizer=self._bias_regularizer,
           activation=helper.NoOpActivation())
-      self._norm0 = self._norm_by_activation(self._activation)(
-          axis=self._bn_axis,
-          momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+      self._norm0 = helper.norm_by_activation(self._activation,
+                                              self._norm_with_quantize,
+                                              self._norm)(
+                                                  axis=self._bn_axis,
+                                                  momentum=self._norm_momentum,
+                                                  epsilon=self._norm_epsilon)
       self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
           tf_utils.get_activation(self._activation, use_keras_layer=True),
           configs.Default8BitActivationQuantizeConfig())
-
     if self._use_depthwise:
       # Depthwise conv.
-      self._conv1 = depthwise_conv2d_quantized(
+      self._conv1 = helper.DepthwiseConv2DQuantized(
          kernel_size=(self._kernel_size, self._kernel_size),
          strides=self._strides,
          padding='same',
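helper.DepthwiseConv2DQuantized replaces the depthwise_conv2d_quantized binding removed from this build method. Per the deleted call, it presumably amounts to (sketch, assumed):

DepthwiseConv2DQuantized = quantize_wrapped_layer(
    tf.keras.layers.DepthwiseConv2D,
    configs.Default8BitConvQuantizeConfig(
        ['depthwise_kernel'], ['activation'], False))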
@@ -617,10 +587,12 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
          depthwise_regularizer=self._depthsize_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=helper.NoOpActivation())
-      self._norm1 = self._norm_by_activation(self._depthwise_activation)(
-          axis=self._bn_axis,
-          momentum=self._norm_momentum,
-          epsilon=self._norm_epsilon)
+      self._norm1 = helper.norm_by_activation(self._depthwise_activation,
+                                              self._norm_with_quantize,
+                                              self._norm)(
+                                                  axis=self._bn_axis,
+                                                  momentum=self._norm_momentum,
+                                                  epsilon=self._norm_epsilon)
       self._depthwise_activation_layer = (
           tfmot.quantization.keras.QuantizeWrapperV2(
               tf_utils.get_activation(self._depthwise_activation,
@@ -648,7 +620,7 @@ def build(self, input_shape: Optional[Union[Sequence[int], tf.Tensor]]):
       self._squeeze_excitation = None

     # Last 1x1 conv.
-    self._conv2 = conv2d_quantized(
+    self._conv2 = helper.Conv2DQuantized(
         filters=self._out_filters,
         kernel_size=1,
         strides=1,
