Skip to content

Commit d5c790a

Browse files
committed
add unit testing and fix quan_dense_bn
1 parent 2f90939 commit d5c790a

File tree

6 files changed

+105
-16
lines changed

6 files changed

+105
-16
lines changed

tensorlayer/layers/convolution/quan_conv_bn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def __init__(
121121
)
122122
)
123123

124+
if self.in_channels:
125+
self.build(None)
126+
self._built = True
127+
124128
if use_gemm:
125129
raise Exception("TODO. The current version use tf.matmul for inferencing.")
126130

tensorlayer/layers/dense/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,5 @@
2424
'DropconnectDense',
2525
'TernaryDense',
2626
'QuanDense',
27-
'QuanDenseLayerWithBN',
27+
'QuanDenseWithBN',
2828
]

tensorlayer/layers/dense/quan_dense_bn.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)
1212

1313
__all__ = [
14-
'QuanDenseLayerWithBN',
14+
'QuanDenseWithBN',
1515
]
1616

1717

18-
class QuanDenseLayerWithBN(Layer):
19-
"""The :class:`QuanDenseLayerWithBN` class is a quantized fully connected layer with BN, which weights are 'bitW' bits and the output of the previous layer
18+
class QuanDenseWithBN(Layer):
19+
"""The :class:`QuanDenseWithBN` class is a quantized fully connected layer with BN, which weights are 'bitW' bits and the output of the previous layer
2020
are 'bitA' bits while inferencing.
2121
2222
Parameters
@@ -47,16 +47,19 @@ class QuanDenseLayerWithBN(Layer):
4747
The initializer for the weight matrix.
4848
W_init_args : dictionary
4949
The arguments for the weight matrix initializer.
50+
in_channels: int
51+
The number of channels of the previous layer.
52+
If None, it will be automatically detected when the layer is forwarded for the first time.
5053
name : a str
5154
A unique layer name.
5255
5356
Examples
5457
---------
5558
>>> import tensorlayer as tl
5659
>>> net = tl.layers.Input([50, 256])
57-
>>> layer = tl.layers.QuanDenseLayerWithBN(128, act='relu', name='qdbn1')(net)
60+
>>> layer = tl.layers.QuanDenseWithBN(128, act='relu', name='qdbn1')(net)
5861
>>> print(layer)
59-
>>> net = tl.layers.QuanDenseLayerWithBN(256, act='relu', name='qdbn2')(net)
62+
>>> net = tl.layers.QuanDenseWithBN(256, act='relu', name='qdbn2')(net)
6063
>>> print(net)
6164
"""
6265

@@ -74,9 +77,10 @@ def __init__(
7477
use_gemm=False,
7578
W_init=tl.initializers.truncated_normal(stddev=0.05),
7679
W_init_args=None,
80+
in_channels=None,
7781
name=None, # 'quan_dense_with_bn',
7882
):
79-
super(QuanDenseLayerWithBN, self).__init__(act=act, W_init_args=W_init_args, name=name)
83+
super(QuanDenseWithBN, self).__init__(act=act, W_init_args=W_init_args, name=name)
8084
self.n_units = n_units
8185
self.decay = decay
8286
self.epsilon = epsilon
@@ -87,6 +91,11 @@ def __init__(
8791
self.beta_init = beta_init
8892
self.use_gemm = use_gemm
8993
self.W_init = W_init
94+
self.in_channels = in_channels
95+
96+
if self.in_channels is not None:
97+
self.build((None, self.in_channels))
98+
self._built = True
9099

91100
logging.info(
92101
"QuanDenseLayerWithBN %s: %d %s" %
@@ -105,9 +114,12 @@ def __repr__(self):
105114
return s.format(classname=self.__class__.__name__, **self.__dict__)
106115

107116
def build(self, inputs_shape):
108-
if len(inputs_shape) != 2:
117+
if self.in_channels is None and len(inputs_shape) != 2:
109118
raise Exception("The input dimension must be rank 2, please reshape or flatten it")
110119

120+
if self.in_channels is None:
121+
self.in_channels = inputs_shape[1]
122+
111123
if self.use_gemm:
112124
raise Exception("TODO. The current version use tf.matmul for inferencing.")
113125

tensorlayer/layers/deprecated.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ def DropconnectDenseLayer(*args, **kwargs):
9191
raise NonExistingLayerError("DropconnectDenseLayer(net, name='a') --> DropconnectDense(name='a')(net)" + __log__)
9292

9393

94+
# dense/quan_dense_bn.py
95+
__all__ += [
96+
'QuanDenseLayerWithBN',
97+
]
98+
99+
100+
def QuanDenseLayerWithBN(*args, **kwargs):
101+
raise NonExistingLayerError("QuanDenseLayerWithBN(net, name='a') --> QuanDenseWithBN(name='a')(net)" + __log__)
102+
103+
94104
# dense/ternary_dense.py
95105
__all__ += [
96106
'TernaryDenseLayer',

tests/layers/test_layers_convolution.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import os
55
import unittest
66

7-
import tensorflow as tf
7+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
88

9+
import tensorflow as tf
910
import tensorlayer as tl
1011
from tensorlayer.layers import *
1112
from tensorlayer.models import *
12-
from tests.utils import CustomTestCase
1313

14-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
14+
from tests.utils import CustomTestCase
1515

1616

1717
class Layer_Convolution_1D_Test(CustomTestCase):
@@ -208,7 +208,11 @@ def setUpClass(cls):
208208

209209
cls.n14 = tl.layers.SubpixelConv2d(scale=2, act=tf.nn.relu, name='subpixelconv2d')(cls.n13)
210210

211-
cls.model = Model(cls.input_layer, cls.n14)
211+
cls.n15 = tl.layers.QuanConv2dWithBN(
212+
n_filter=64, filter_size=(5, 5), strides=(1, 1), act=tf.nn.relu, padding='SAME', name='quancnnbn2d'
213+
)(cls.n14)
214+
215+
cls.model = Model(cls.input_layer, cls.n15)
212216
print("Testing Conv2d model: \n", cls.model)
213217

214218
# cls.n12 = tl.layers.QuanConv2d(cls.n11, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='quancnn')
@@ -321,6 +325,10 @@ def test_layer_n13(self):
321325
def test_layer_n14(self):
322326
self.assertEqual(self.n14.get_shape().as_list()[1:], [24, 24, 8])
323327

328+
def test_layer_n15(self):
329+
self.assertEqual(len(self.n15._info[0].layer.all_weights), 5)
330+
self.assertEqual(self.n15.get_shape().as_list()[1:], [24, 24, 64])
331+
324332
# def test_layer_n8(self):
325333
#
326334
# self.assertEqual(len(self.n8.all_layers), 9)

tests/layers/test_layers_dense.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import os
44
import unittest
55

6-
import numpy as np
7-
import tensorflow as tf
6+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
87

8+
import tensorflow as tf
99
import tensorlayer as tl
1010
from tensorlayer.layers import *
1111
from tensorlayer.models import *
12-
from tests.utils import CustomTestCase
1312

14-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
13+
from tests.utils import CustomTestCase
14+
import numpy as np
1515

1616

1717
class Layer_BinaryDense_Test(CustomTestCase):
@@ -243,6 +243,61 @@ def test_exception(self):
243243
print(e)
244244

245245

246+
class Layer_QuanDenseWithBN_Test(CustomTestCase):
247+
248+
@classmethod
249+
def setUpClass(cls):
250+
print("-" * 20, "Layer_QuanDenseWithBN_Test", "-" * 20)
251+
cls.batch_size = 4
252+
cls.inputs_shape = [cls.batch_size, 10]
253+
254+
cls.ni = Input(cls.inputs_shape, name='input_layer')
255+
cls.layer1 = QuanDenseWithBN(n_units=5)
256+
nn = cls.layer1(cls.ni)
257+
cls.layer1._nodes_fixed = True
258+
cls.M = Model(inputs=cls.ni, outputs=nn)
259+
260+
cls.layer2 = QuanDenseWithBN(n_units=5, in_channels=10)
261+
cls.layer2._nodes_fixed = True
262+
263+
cls.inputs = tf.random.uniform((cls.inputs_shape))
264+
cls.n1 = cls.layer1(cls.inputs)
265+
cls.n2 = cls.layer2(cls.inputs)
266+
cls.n3 = cls.M(cls.inputs, is_train=True)
267+
268+
print(cls.layer1)
269+
print(cls.layer2)
270+
271+
@classmethod
272+
def tearDownClass(cls):
273+
pass
274+
275+
def test_layer_n1(self):
276+
print(self.n1[0])
277+
278+
def test_layer_n2(self):
279+
print(self.n2[0])
280+
281+
def test_model_n3(self):
282+
print(self.n3[0])
283+
284+
def test_exception(self):
285+
try:
286+
layer = QuanDenseWithBN(n_units=5)
287+
inputs = Input([4, 10, 5], name='ill_inputs')
288+
out = layer(inputs)
289+
self.fail('ill inputs')
290+
except Exception as e:
291+
print(e)
292+
293+
try:
294+
layer = QuanDenseWithBN(n_units=5, use_gemm=True)
295+
out = layer(self.ni)
296+
self.fail('use gemm')
297+
except Exception as e:
298+
print(e)
299+
300+
246301
class Layer_TernaryDense_Test(CustomTestCase):
247302

248303
@classmethod

0 commit comments

Comments
 (0)